python实现基于信息增益的决策树归纳
作者:conggova 发布时间:2022-05-20 14:22:47
标签:python,信息增益,决策树
本文实例为大家分享了基于信息增益的决策树归纳的Python实现代码,供大家参考,具体内容如下
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.mlab as mlab
import matplotlib.pyplot as plt
from copy import copy
#加载训练数据
#文件格式:属性标号,是否连续【yes|no】,属性说明
attribute_file_dest = 'F:\\bayes_categorize\\attribute.dat'
attribute_file = open(attribute_file_dest)
#文件格式:rec_id,attr1_value,attr2_value,...,attrn_value,class_id
trainning_data_file_dest = 'F:\\bayes_categorize\\trainning_data.dat'
trainning_data_file = open(trainning_data_file_dest)
#文件格式:class_id,class_desc
class_desc_file_dest = 'F:\\bayes_categorize\\class_desc.dat'
class_desc_file = open(class_desc_file_dest)
root_attr_dict = {}
for line in attribute_file :
line = line.strip()
fld_list = line.split(',')
root_attr_dict[int(fld_list[0])] = tuple(fld_list[1:])
class_dict = {}
for line in class_desc_file :
line = line.strip()
fld_list = line.split(',')
class_dict[int(fld_list[0])] = fld_list[1]
trainning_data_dict = {}
class_member_set_dict = {}
for line in trainning_data_file :
line = line.strip()
fld_list = line.split(',')
rec_id = int(fld_list[0])
a1 = int(fld_list[1])
a2 = int(fld_list[2])
a3 = float(fld_list[3])
c_id = int(fld_list[4])
if c_id not in class_member_set_dict :
class_member_set_dict[c_id] = set()
class_member_set_dict[c_id].add(rec_id)
trainning_data_dict[rec_id] = (a1 , a2 , a3 , c_id)
attribute_file.close()
class_desc_file.close()
trainning_data_file.close()
class_possibility_dict = {}
for c_id in class_member_set_dict :
class_possibility_dict[c_id] = (len(class_member_set_dict[c_id]) + 0.0)/len(trainning_data_dict)
#等待分类的数据
data_to_classify_file_dest = 'F:\\bayes_categorize\\trainning_data_new.dat'
data_to_classify_file = open(data_to_classify_file_dest)
data_to_classify_dict = {}
for line in data_to_classify_file :
line = line.strip()
fld_list = line.split(',')
rec_id = int(fld_list[0])
a1 = int(fld_list[1])
a2 = int(fld_list[2])
a3 = float(fld_list[3])
c_id = int(fld_list[4])
data_to_classify_dict[rec_id] = (a1 , a2 , a3 , c_id)
data_to_classify_file.close()
'''
决策树的表达
结点的需求:
1、指示出是哪一种分区 一共3种 一是离散穷举 二是连续有分裂点 三是离散有判别集合 零是叶子结点
2、保存分类所需信息
3、子结点列表
每个结点用Tuple类型表示
元素一是整形,取值123 分别对应两种分裂类型
元素二是集合类型 对于1保存所有的离散值 对于2保存分裂点 对于3保存判别集合 对于0保存分类结果类标号
元素三是dict key对于1来说是某个的离散值 对于23来说只有12两种 对于2来说1代表小于等于分裂点
对于3来说1代表属于判别集合
'''
#对于一个成员列表,计算其熵
#公式为 Info_D = - sum(pi * log2 (pi)) pi为一个元素属于Ci的概率,用|Ci|/|D|计算 ,对所有分类求和
def get_entropy( member_list ) :
#成员总数
mem_cnt = len(member_list)
#首先找出member中所包含的分类
class_dict = {}
for mem_id in member_list :
c_id = trainning_data_dict[mem_id][3]
if c_id not in class_dict :
class_dict[c_id] = set()
class_dict[c_id].add(mem_id)
tmp_sum = 0.0
for c_id in class_dict :
pi = ( len(class_dict[c_id]) + 0.0 ) / mem_cnt
tmp_sum += pi * mlab.log2(pi)
tmp_sum = -tmp_sum
return tmp_sum
def attribute_selection_method( member_list , attribute_dict ) :
#先计算原始的熵
info_D = get_entropy(member_list)
max_info_Gain = 0.0
attr_get = 0
split_point = 0.0
for attr_id in attribute_dict :
#对于每一个属性计算划分后的熵
#信息增益等于原始的熵减去划分后的熵
info_D_new = 0
#如果是连续属性
if attribute_dict[attr_id][0] == 'yes' :
#先得到memberlist中此属性的取值序列,把序列中每一对相邻项的中值作为划分点计算熵
#找出其中最小的,作为此连续属性的划分点
value_list = []
for mem_id in member_list :
value_list.append(trainning_data_dict[mem_id][attr_id - 1])
#获取相邻元素的中值序列
mid_value_list = []
value_list.sort()
#print value_list
last_value = None
for value in value_list :
if value == last_value :
continue
if last_value is not None :
mid_value_list.append((last_value+value)/2)
last_value = value
#print mid_value_list
#对于中值序列做循环
#计算以此值做为划分点的熵
#总的熵等于两个划分的熵乘以两个划分的比重
min_info = 1000000000.0
total_mens = len(member_list) + 0.0
for mid_value in mid_value_list :
#小于mid_value的mem
less_list = []
#大于
more_list = []
for tmp_mem_id in member_list :
if trainning_data_dict[tmp_mem_id][attr_id - 1] <= mid_value :
less_list.append(tmp_mem_id)
else :
more_list.append(tmp_mem_id)
sum_info = len(less_list)/total_mens * get_entropy(less_list) \
+ len(more_list)/total_mens * get_entropy(more_list)
if sum_info < min_info :
min_info = sum_info
split_point = mid_value
info_D_new = min_info
#如果是离散属性
else :
#计算划分后的熵
#采用循环累加的方式
attr_value_member_dict = {} #键为attribute value , 值为memberlist
for tmp_mem_id in member_list :
attr_value = trainning_data_dict[tmp_mem_id][attr_id - 1]
if attr_value not in attr_value_member_dict :
attr_value_member_dict[attr_value] = []
attr_value_member_dict[attr_value].append(tmp_mem_id)
#将每个离散值的熵乘以比重加到这上面
total_mens = len(member_list) + 0.0
sum_info = 0.0
for a_value in attr_value_member_dict :
sum_info += len(attr_value_member_dict[a_value])/total_mens \
* get_entropy(attr_value_member_dict[a_value])
info_D_new = sum_info
info_Gain = info_D - info_D_new
if info_Gain > max_info_Gain :
max_info_Gain = info_Gain
attr_get = attr_id
#如果是离散的
#print 'attr_get ' + str(attr_get)
if attribute_dict[attr_get][0] == 'no' :
return (1 , attr_get , split_point)
else :
return (2 , attr_get , split_point)
#第三类先不考虑
def get_decision_tree(father_node , key , member_list , attr_dict ) :
#最终的结果是新建一个结点,并且添加到father_node的sub_node_dict,对key为键
#检查memberlist 如果都是同类的,则生成一个叶子结点,set里面保存类标号
class_set = set()
for mem_id in member_list :
class_set.add(trainning_data_dict[mem_id][3])
if len(class_set) == 1 :
father_node[2][key] = (0 , (1 , class_set) , {} )
return
#检查attribute_list,如果为空,产生叶子结点,类标号为memberlist中多数元素的类标号
#如果几个类的成员等量,则打印提示,并且全部添加到set里面
if not attr_dict :
class_cnt_dict = {}
for mem_id in member_list :
c_id = trainning_data_dict[mem_id][3]
if c_id not in class_cnt_dict :
class_cnt_dict[c_id] = 1
else :
class_cnt_dict[c_id] += 1
class_set = set()
max_cnt = 0
for c_id in class_cnt_dict :
if class_cnt_dict[c_id] > max_cnt :
max_cnt = class_cnt_dict[c_id]
class_set.clear()
class_set.add(c_id)
elif class_cnt_dict[c_id] == max_cnt :
class_set.add(c_id)
if len(class_set) > 1 :
print 'more than one class !'
father_node[2][key] = (0 , (1 , class_set ) , {} )
return
#找出最好的分区方案 , 暂不考虑第三种划分方法
#比较所有离散属性和所有连续属性的所有中值点划分的信息增益
split_criterion = attribute_selection_method(member_list , attr_dict)
#print split_criterion
selected_plan_id = split_criterion[0]
selected_attr_id = split_criterion[1]
#如果采用的是离散属性做为分区方案,删除这个属性
new_attr_dict = copy(attr_dict)
if attr_dict[selected_attr_id][0] == 'no' :
del new_attr_dict[selected_attr_id]
#建立一个结点new_node,father_node[2][key] = new_node
#然后对new node的每一个key , sub_member_list,
#调用 get_decision_tree(new_node , new_key , sub_member_list , new_attribute_dict)
#实现递归
ele2 = ( selected_attr_id , set() )
#如果是1 , ele2保存所有离散值
if selected_plan_id == 1 :
for mem_id in member_list :
ele2[1].add(trainning_data_dict[mem_id][selected_attr_id - 1])
#如果是2,ele2保存分裂点
elif selected_plan_id == 2 :
ele2[1].add(split_criterion[2])
#如果是3则保存判别集合,先不管
else :
print 'not completed'
pass
new_node = ( selected_plan_id , ele2 , {} )
father_node[2][key] = new_node
#生成KEY,并递归调用
if selected_plan_id == 1 :
#每个attr_value是一个key
attr_value_member_dict = {}
for mem_id in member_list :
attr_value = trainning_data_dict[mem_id][selected_attr_id - 1 ]
if attr_value not in attr_value_member_dict :
attr_value_member_dict[attr_value] = []
attr_value_member_dict[attr_value].append(mem_id)
for attr_value in attr_value_member_dict :
get_decision_tree(new_node , attr_value , attr_value_member_dict[attr_value] , new_attr_dict)
pass
elif selected_plan_id == 2 :
#key 只有12 , 小于等于分裂点的是1 , 大于的是2
less_list = []
more_list = []
for mem_id in member_list :
attr_value = trainning_data_dict[mem_id][selected_attr_id - 1 ]
if attr_value <= split_criterion[2] :
less_list.append(mem_id)
else :
more_list.append(mem_id)
#if len(less_list) != 0 :
get_decision_tree(new_node , 1 , less_list , new_attr_dict)
#if len(more_list) != 0 :
get_decision_tree(new_node , 2 , more_list , new_attr_dict)
pass
#如果是3则保存判别集合,先不管
else :
print 'not completed'
pass
def get_class_sub(node , tp ) :
#
attr_id = node[1][0]
plan_id = node[0]
key = 0
if plan_id == 0 :
return node[1][1]
elif plan_id == 1 :
key = tp[attr_id - 1]
elif plan_id == 2 :
split_point = tuple(node[1][1])[0]
attr_value = tp[attr_id - 1]
if attr_value <= split_point :
key = 1
else :
key = 2
else :
print 'error'
return set()
return get_class_sub(node[2][key] , tp )
def get_class(r_node , tp) :
#tp为一组属性值
if r_node[0] != -1 :
print 'error'
return set()
if 1 in r_node[2] :
return get_class_sub(r_node[2][1] , tp)
else :
print 'error'
return set()
if __name__ == '__main__' :
root_node = ( -1 , set() , {} )
mem_list = trainning_data_dict.keys()
get_decision_tree(root_node , 1 , mem_list , root_attr_dict )
#测试分类器的准确率
diff_cnt = 0
for mem_id in data_to_classify_dict :
c_id = get_class(root_node , data_to_classify_dict[mem_id][0:3])
if tuple(c_id)[0] != data_to_classify_dict[mem_id][3] :
print tuple(c_id)[0]
print data_to_classify_dict[mem_id][3]
print 'different'
diff_cnt += 1
print diff_cnt
来源:https://blog.csdn.net/conggova/article/details/77528966
0
投稿
猜你喜欢
- 图片缩放会失真是真理,在浏览器里也一样,貌似使用传说中的双三次插值可以让失真看起来比较不明显,但是真的想不通IE7已经实现了,却不默认打开,
- 在这篇文章中,我将努力揭开Mobile Web开发的神秘面纱,换句话说,也就是为了移动设备上的用户体验可以被接受,代码得怎么设计。我将阐述“
- 本文实例讲述了Python实现求数列和的方法。分享给大家供大家参考,具体如下:问题:输入输入数据有多组,每组占一行,由两个整数n(n<
- <% dim week_ymd(8) '测出可以手动设定日期,比如this_ymd=#2008-04-1
- 以图像处理见长的微软Live实验室,最近发布了一款新作:Pivot。装完启动后的第一印象就是一款浏览器,和IE、FF、Chrome又不太一样
- 几年前,看到一台湾人写的一段程序(好像是《日语基础》),在网页上实现音视频与文字的同步播放(就是音视频播到哪部分,相应的文字就亮显,点击某一
- 一、函数的变量作用域和可见性1.全局变量在main函数执行之前初始化,全局可见2.局部变量在函数内部或者if、for等语句块有效,使用之后外
- 1.nginx使用哪种网络协议? nginx是应用层 我觉得从下往上的话 传输层用的是tcp/ip 应用层用的是http fastcgi负责
- 软件环境: 1、操作系统:Windows 2000 Server 2、数 据 库:Oracle 8i R2 (8.1.7) for NT 企
- 2020年4月4日,是个特殊的日子,我们看到朋友圈很多灰化的图片.今天我们就聊聊图片灰度处理这事儿.PIL的基本概念:PIL中所涉及的基本概
- 新浪天气预报代码,需要的朋友可以复制下面的代码到要显示的页面,新浪代码 :<IFRAME WIDTH='260
- Protocol Buffers (类似XML的一种数据描述语言)最新版本2.3里,protoc—py_out命令只生成原生的P
- 在Oracle数据库中,DBA可以通过观测一定的表或视图来了解当前空间的使用状况,进而作出可能的调整决定。 一.表空间的自由空间 通过对表空
- 第一次用Python写这种比较实用且好玩的东西,权当练手吧游戏说明:* P键控制“暂停/开始”* 方向键控制贪吃蛇的方向源代码如下:from
- 这两天终于忍不住的去实验了一下,为什么网页的字体有时会显示成超级无敌难看的宋体呢?其实宋体不难看,难看的只是把它放在Leopard下,没有点
- 本文实例为大家分享了python多进程实现文件下载传输功能的具体代码,供大家参考,具体内容如下需求:实现文件夹拷贝功能(包括文件内的文件),
- 首先看一下这三个函数:rtrim() ltrim() trim();rtrim()定义以及用法: rtrim() 函数移除字符串右侧的空白字
- 导语:哈喽,哈喽~今天小编又来分享小游戏了——flappy bird(飞扬的小鸟),这个游戏非常的经
- 在上一篇博客中,已经将环境搭建好了。现在,我们利用搭建的环境来运行一条测试脚本,脚本中启动一个计算器的应用,并实现加法的运算。创建模拟器在运
- 如何做一个文本搜索? 比较简单,见下:<%Head = "搜索"SearchStri