Python机器学习之决策树算法
作者:自在逍遥 发布时间:2022-06-07 06:38:43
一、决策树原理
决策树是用样本的属性作为结点,用属性的取值作为分支的树结构。
决策树的根结点是所有样本中信息量最大的属性。树的中间结点是该结点为根的子树所包含的样本子集中信息量最大的属性。决策树的叶结点是样本的类别值。决策树是一种知识表示形式,它是对所有样本数据的高度概括决策树能准确地识别所有样本的类别,也能有效地识别新样本的类别。
决策树算法ID3的基本思想:
首先找出最有判别力的属性,把样例分成多个子集,每个子集又选择最有判别力的属性进行划分,一直进行到所有子集仅包含同一类型的数据为止。最后得到一棵决策树。
J.R.Quinlan的工作主要是引进了信息论中的信息增益,他将其称为信息增益(information gain),作为属性判别能力的度量,设计了构造决策树的递归算法。
举例子比较容易理解:
对于气候分类问题,属性为:
天气(A1) 取值为: 晴,多云,雨
气温(A2) 取值为: 冷 ,适中,热
湿度(A3) 取值为: 高 ,正常
风 (A4) 取值为: 有风, 无风
每个样例属于不同的类别,此例仅有两个类别,分别为P,N。P类和N类的样例分别称为正例和反例。将一些已知的正例和反例放在一起便得到训练集。
由ID3算法得出一棵正确分类训练集中每个样例的决策树,见下图。
决策树叶子为类别名,即P 或者N。其它结点由样例的属性组成,每个属性的不同取值对应一分枝。
若要对一样例分类,从树根开始进行测试,按属性的取值分枝向下进入下层结点,对该结点进行测试,过程一直进行到叶结点,样例被判为属于该叶结点所标记的类别。
现用图来判一个具体例子,
某天早晨气候描述为:
天气:多云
气温:冷
湿度:正常
风: 无风
它属于哪类气候呢?-------------从图中可判别该样例的类别为P类。
ID3就是要从表的训练集构造图这样的决策树。实际上,能正确分类训练集的决策树不止一棵。Quinlan的ID3算法能得出结点最少的决策树。
ID3算法:
1. 对当前例子集合,计算各属性的信息增益;
2. 选择信息增益最大的属性Ak;
3. 把在Ak处取值相同的例子归于同一子集,Ak取几个值就得几个子集;
4.对既含正例又含反例的子集,递归调用建树算法;
5. 若子集仅含正例或反例,对应分枝标上P或N,返回调用处。
一般只要涉及到树的情况,经常会要用到递归。
对于气候分类问题进行具体计算有:
1、 信息熵的计算: 其中S是样例的集合, P(ui)是类别i出现概率:
|S|表示例子集S的总数,|ui|表示类别ui的例子数。对9个正例和5个反例有:
P(u1)=9/14
P(u2)=5/14
H(S)=(9/14)log(14/9)+(5/14)log(14/5)=0.94bit
2、信息增益的计算:
其中A是属性,Value(A)是属性A取值的集合,v是A的某一属性值,Sv是S中A的值为v的样例集合,| Sv |为Sv中所含样例数。
以属性A1为例,根据信息增益的计算公式,属性A1的信息增益为
S=[9+,5-] //原样例集中共有14个样例,9个正例,5个反例
S晴=[2+,3-]//属性A1取值晴的样例共5个,2正,3反
S多云=[4+,0-] //属性A1取值多云的样例共4个,4正,0反
S雨=[3+,2-] //属性A1取值晴的样例共5个,3正,2反
故
3、结果为
属性A1的信息增益最大,所以被选为根结点。
4、建决策树的根和叶子
ID3算法将选择信息增益最大的属性天气作为树根,在14个例子中对天气的3个取值进行分枝,3 个分枝对应3 个子集,分别是:
其中S2中的例子全属于P类,因此对应分枝标记为P,其余两个子集既含有正例又含有反例,将递归调用建树算法。
5、递归建树
分别对S1和S3子集递归调用ID3算法,在每个子集中对各属性求信息增益.
(1)对S1,湿度属性信息增益最大,以它为该分枝的根结点,再向下分枝。湿度取高的例子全为N类,该分枝标记N。取值正常的例子全为P类,该分枝标记P。
(2)对S3,风属性信息增益最大,则以它为该分枝根结点。再向下分枝,风取有风时全为N类,该分枝标记N。取无风时全为P类,该分枝标记P。
二、PYTHON实现决策树算法分类
本代码为machine learning in action 第三章例子,亲测无误。
1、计算给定数据shangnon数据的函数:
def calcShannonEnt(dataSet):
#calculate the shannon value
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet: #create the dictionary for all of the data
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob*log(prob,2) #get the log value
return shannonEnt
2. 创建数据的函数
def createDataSet():
dataSet = [[1,1,'yes'],
[1,1, 'yes'],
[1,0,'no'],
[0,1,'no'],
[0,1,'no']]
labels = ['no surfacing','flippers']
return dataSet, labels
3.划分数据集,按照给定的特征划分数据集
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value: #abstract the fature
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
4.选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0])-1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0; bestFeature = -1
for i in range(numFeatures):
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i , value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy +=prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if(infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
5.递归创建树
用于找出出现次数最多的分类名称的函数
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
用于创建树的函数代码
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
# the type is the same, so stop classify
if classList.count(classList[0]) == len(classList):
return classList[0]
# traversal all the features and choose the most frequent feature
if (len(dataSet[0]) == 1):
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
#get the list which attain the whole properties
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
然后是在python 名利提示符号输入如下命令:
myDat, labels = trees.createDataSet()
myTree = trees.createTree(myDat,labels)
print myTree
结果是:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
6.实用决策树进行分类的函数
def classify(inputTree, featLabels, testVec):
firstStr = inputTree.keys()[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else: classLabel = secondDict[key]
return classLabel
在Python命令提示符,输入:
trees.classify(myTree,labels,[1,0])
得到结果:
'no'
Congratulation. Oh yeah. You did it.!!!
来源:http://blog.csdn.net/alvine008/article/details/37760639


猜你喜欢
- 1. 读取数据用pandas中的read_csv()函数读取出csv文件中的数据:import pandas as pddf = pd.re
- 什么是虚拟环境这是 Python 3.3 的新特性:https://www.python.org/dev/peps/pep-0405/假设自
- 超酷的js图片轮换/轮播 渐变效果··来自腾讯刚刚在腾讯女性频道上看到一个很酷的图片渐变轮换效果·····于是乎····抠下来了···分享·
- 1.进入Mysqld如果已经设置Mysql/Bin环境变量,直接在CMD里输入命令,如果没有设置Mysql环境变量,去Mysql安装目录的B
- 概述np.ones()函数返回给定形状和数据类型的新数组,其中元素的值设置为1。此函数与numpy zeros()函数非常相似。用法np.o
- 本文实例讲述了Python 面向对象之类class和对象基本用法。分享给大家供大家参考,具体如下:类(class):定义一件事物的抽象特点,
- 前言我们在往期对matplotlib.pyplot()方法学习,到现在我们已经会绘制折线图、柱状图、散点等常规的图表啦(往期的内容如下,大家
- 1) chocolatappChocolat是最新出现的一款强大的Mac系统文本编辑器,兼具原生的Cocoa及强大的文本编辑功能。Choco
- 【原文地址】My "First Look at Orcas" Presentation 【原文发表日期】 Th
- <script> Array.prototype.swap = function(i, j) { var temp = this
- 前言Laravel 队列为不同的后台队列服务提供统一的 API,例如 Beanstalk,Amazon SQS,Redis,甚至其他基于关系
- 需求:(1) 获取你对象chrome前一天的浏览记录中的所有网址(url)和访问时间,并存在一个txt文件中(2)将这个txt文件发送给指定
- 原来的程序是使用sqlite这个嵌入式数据库作为Remit(code name)的数据源的,因为NHibernate支持这个,然而有一点不好
- merge()import pandas as pdpd.merge(DateFrame1,DateFrame2,on = '
- 数据库快照是怎样工作的可以使用典型的数据库命令CREATE DATABASE语句来生成一个数据库快照,在声明中有一个源数据库快照的附加说明。
- '把pattern 又修改了下'code
- 刚接触python不久,编程也是三脚猫,所以对常用的这几个工具还没有一个好的使用习惯,毕竟程序语言是头顺毛驴。所以最近在工作中使用的时候在使
- JavaScript中,对象的extensible属性用于表示是否允许在对象中动态添加新的property。ECMAScript 3标准中,
- 我们平时接触的长乘法,按位相乘,是一种时间复杂度为 O(n ^ 2) 的算法。今天,我们来介绍一种时间复杂度为 O (n ^ log 3)
- 为啥要写这个脚本五一前的准备下班的时候,看到同事为了做数据库的某个表的数据字典,在做一个复杂的人工操作,就是一个字段一个字段的纯手撸,那速度