解读python如何实现决策树算法
作者:laozhang 发布时间:2021-08-26 19:52:01
标签:python,决策树,算法
数据描述
每条数据项储存在列表中,最后一列储存结果
多条数据项形成数据集
data=[[d1,d2,d3...dn,result],
[d1,d2,d3...dn,result],
.
.
[d1,d2,d3...dn,result]]
决策树数据结构
class DecisionNode:
'''决策树节点
'''
def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):
'''初始化决策树节点
args:
col -- 按数据集的col列划分数据集
value -- 以value作为划分col列的参照
result -- 只有叶子节点有,代表最终划分出的子数据集结果统计信息。{‘结果':结果出现次数}
rb,fb -- 代表左右子树
'''
self.col=col
self.value=value
self.results=results
self.tb=tb
self.fb=fb
决策树分类的最终结果是将数据项划分出了若干子集,其中每个子集的结果都一样,所以这里采用{‘结果':结果出现次数}的方式表达每个子集
def pideset(rows,column,value):
'''依据数据集rows的column列的值,判断其与参考值value的关系对数据集进行拆分
返回两个数据集
'''
split_function=None
#value是数值类型
if isinstance(value,int) or isinstance(value,float):
#定义lambda函数当row[column]>=value时返回true
split_function=lambda row:row[column]>=value
#value是字符类型
else:
#定义lambda函数当row[column]==value时返回true
split_function=lambda row:row[column]==value
#将数据集拆分成两个
set1=[row for row in rows if split_function(row)]
set2=[row for row in rows if not split_function(row)]
#返回两个数据集
return (set1,set2)
def uniquecounts(rows):
'''计算数据集rows中有几种最终结果,计算结果出现次数,返回一个字典
'''
results={}
for row in rows:
r=row[len(row)-1]
if r not in results: results[r]=0
results[r]+=1
return results
def giniimpurity(rows):
'''返回rows数据集的基尼不纯度
'''
total=len(rows)
counts=uniquecounts(rows)
imp=0
for k1 in counts:
p1=float(counts[k1])/total
for k2 in counts:
if k1==k2: continue
p2=float(counts[k2])/total
imp+=p1*p2
return imp
def entropy(rows):
'''返回rows数据集的熵
'''
from math import log
log2=lambda x:log(x)/log(2)
results=uniquecounts(rows)
ent=0.0
for r in results.keys():
p=float(results[r])/len(rows)
ent=ent-p*log2(p)
return ent
def build_tree(rows,scoref=entropy):
'''构造决策树
'''
if len(rows)==0: return DecisionNode()
current_score=scoref(rows)
# 最佳信息增益
best_gain=0.0
#
best_criteria=None
#最佳划分
best_sets=None
column_count=len(rows[0])-1
#遍历数据集的列,确定分割顺序
for col in range(0,column_count):
column_values={}
# 构造字典
for row in rows:
column_values[row[col]]=1
for value in column_values.keys():
(set1,set2)=pideset(rows,col,value)
p=float(len(set1))/len(rows)
# 计算信息增益
gain=current_score-p*scoref(set1)-(1-p)*scoref(set2)
if gain>best_gain and len(set1)>0 and len(set2)>0:
best_gain=gain
best_criteria=(col,value)
best_sets=(set1,set2)
# 如果划分的两个数据集熵小于原数据集,进一步划分它们
if best_gain>0:
trueBranch=build_tree(best_sets[0])
falseBranch=build_tree(best_sets[1])
return DecisionNode(col=best_criteria[0],value=best_criteria[1],
tb=trueBranch,fb=falseBranch)
# 如果划分的两个数据集熵不小于原数据集,停止划分
else:
return DecisionNode(results=uniquecounts(rows))
def print_tree(tree,indent=''):
if tree.results!=None:
print(str(tree.results))
else:
print(str(tree.col)+':'+str(tree.value)+'? ')
print(indent+'T->',end='')
print_tree(tree.tb,indent+' ')
print(indent+'F->',end='')
print_tree(tree.fb,indent+' ')
def getwidth(tree):
if tree.tb==None and tree.fb==None: return 1
return getwidth(tree.tb)+getwidth(tree.fb)
def getdepth(tree):
if tree.tb==None and tree.fb==None: return 0
return max(getdepth(tree.tb),getdepth(tree.fb))+1
def drawtree(tree,jpeg='tree.jpg'):
w=getwidth(tree)*100
h=getdepth(tree)*100+120
img=Image.new('RGB',(w,h),(255,255,255))
draw=ImageDraw.Draw(img)
drawnode(draw,tree,w/2,20)
img.save(jpeg,'JPEG')
def drawnode(draw,tree,x,y):
if tree.results==None:
# Get the width of each branch
w1=getwidth(tree.fb)*100
w2=getwidth(tree.tb)*100
# Determine the total space required by this node
left=x-(w1+w2)/2
right=x+(w1+w2)/2
# Draw the condition string
draw.text((x-20,y-10),str(tree.col)+':'+str(tree.value),(0,0,0))
# Draw links to the branches
draw.line((x,y,left+w1/2,y+100),fill=(255,0,0))
draw.line((x,y,right-w2/2,y+100),fill=(255,0,0))
# Draw the branch nodes
drawnode(draw,tree.fb,left+w1/2,y+100)
drawnode(draw,tree.tb,right-w2/2,y+100)
else:
txt=' \n'.join(['%s:%d'%v for v in tree.results.items()])
draw.text((x-20,y),txt,(0,0,0))
对测试数据进行分类(附带处理缺失数据)
def mdclassify(observation,tree):
'''对缺失数据进行分类
args:
observation -- 发生信息缺失的数据项
tree -- 训练完成的决策树
返回代表该分类的结果字典
'''
# 判断数据是否到达叶节点
if tree.results!=None:
# 已经到达叶节点,返回结果result
return tree.results
else:
# 对数据项的col列进行分析
v=observation[tree.col]
# 若col列数据缺失
if v==None:
#对tree的左右子树分别使用mdclassify,tr是左子树得到的结果字典,fr是右子树得到的结果字典
tr,fr=mdclassify(observation,tree.tb),mdclassify(observation,tree.fb)
# 分别以结果占总数比例计算得到左右子树的权重
tcount=sum(tr.values())
fcount=sum(fr.values())
tw=float(tcount)/(tcount+fcount)
fw=float(fcount)/(tcount+fcount)
result={}
# 计算左右子树的加权平均
for k,v in tr.items():
result[k]=v*tw
for k,v in fr.items():
# fr的结果k有可能并不在tr中,在result中初始化k
if k not in result:
result[k]=0
# fr的结果累加到result中
result[k]+=v*fw
return result
# col列没有缺失,继续沿决策树分类
else:
if isinstance(v,int) or isinstance(v,float):
if v>=tree.value: branch=tree.tb
else: branch=tree.fb
else:
if v==tree.value: branch=tree.tb
else: branch=tree.fb
return mdclassify(observation,branch)
tree=build_tree(my_data)
print(mdclassify(['google',None,'yes',None],tree))
print(mdclassify(['google','France',None,None],tree))
决策树剪枝
def prune(tree,mingain):
'''对决策树进行剪枝
args:
tree -- 决策树
mingain -- 最小信息增益
返回
'''
# 修剪非叶节点
if tree.tb.results==None:
prune(tree.tb,mingain)
if tree.fb.results==None:
prune(tree.fb,mingain)
#合并两个叶子节点
if tree.tb.results!=None and tree.fb.results!=None:
tb,fb=[],[]
for v,c in tree.tb.results.items():
tb+=[[v]]*c
for v,c in tree.fb.results.items():
fb+=[[v]]*c
#计算熵减少情况
delta=entropy(tb+fb)-(entropy(tb)+entropy(fb)/2)
#熵的增加量小于mingain,可以合并分支
if delta<mingain:
tree.tb,tree.fb=None,None
tree.results=uniquecounts(tb+fb)
0
投稿
猜你喜欢
- 这里主要是解决multipart/form-data这种格式的文件上传,基本现在http协议上传文件基本上都是通过这种格式上传1 思路一般情
- JSON(JavaScript Object Notation)是一种轻量级的数据交换格式,它基于ECMAScript的一个子集。 JSON
- 我们以用户查询语句为 https://www.aspxhome.com/chunfeng.asp为例来查询用户资料将从数据库Contact1
- 1. lr_scheduler相关lr_scheduler = WarmupLinearSchedule(optimizer, warmup
- Data Points Archive 有时, 为了让应用程序运行得更快,所做的全部工作就是在这里或那里做一些很小调整。啊,但关键在于确定如
- 目前,我们要在网页中使用圆角效果,总是通过切图然后嵌套很多div,用背景来实现圆角效果。对于前端开发工程师来说,圆角的确是一个让人又爱又恨的
- this指向当前作用域的对象,如果找不到,往上一层找,直到window。this 关键字很好用,很灵活,正因为很灵活,所以一不小心你就会掉进
- 如果你使用过大部分,那么你的ASP功力应该是非常高的了ADO对象(太常用了):ConnectionCommandRecordSetRecor
- 简介 本次项目登录注册验证是对之前学习知识点的加深学习,这次项目的练习的知识点有函数、判断语句、循环语句、文件操作等。项目流程 运行代码之后
- 如果是windows安装完成后,需要将'\Python27\Scripts\'加入系统环境变量# coding=utf-8i
- 本文实例讲述了Django开发的简易留言板。分享给大家供大家参考,具体如下:Django在线留言板小练习环境ubuntu16.04 + py
- 对所有数据进行整合与管理当你使用SQL Server 2008企业级的数据仓库平台时,你可以高效的操纵所有数据,并对其进行统一管理存储。◆合
- 首先为什么会有axis这个概念?因为在numpy模块中,大多数处理的是矩阵或者多维数组,同时,对多维数组或者矩阵的操作有多种可能,为了帮助实
- 一切皆是对象在 Python 一切皆是对象,包括所有类型的常量与变量,整型,布尔型,甚至函数。 参见stackoverflow上的一个问题
- 格式:Download.asp?FileName=要下载的文件名 代码如下:Dim Stream Dim Co
- 一、张量定义张量:TensorFlow的张量是n维数组,类型为tf.Tensor。标量:一个数字 (0阶张量)向量:一维数组 (1阶张量)矩
- ajax缓存和编码问题不难解决,下面是解决方法。编码问题默认使用UTF-8,如果一旦发现对象找不到的情况,可能js中输入了中文,同时js的编
- 大多数网站维护都采用“多人协作,共同管理”方式。某个人负责一个(或者多个)栏目,他只能对他负责的栏目进
- 函数原型:getopt.getopt(args, shortopts, longopts=[])参数解释:  
- 如下所示:import matplotlib.pyplot as pltimport numpy as npx = [11422,11360