python实现ID3决策树算法
作者:杨柳岸晓风 发布时间:2023-04-13 09:35:28
标签:python,ID3,决策树
ID3决策树是以信息增益作为决策标准的一种贪心决策树算法
# -*- coding: utf-8 -*-
from numpy import *
import math
import copy
import cPickle as pickle
class ID3DTree(object):
def __init__(self): # 构造方法
self.tree = {} # 生成树
self.dataSet = [] # 数据集
self.labels = [] # 标签集
# 数据导入函数
def loadDataSet(self, path, labels):
recordList = []
fp = open(path, "rb") # 读取文件内容
content = fp.read()
fp.close()
rowList = content.splitlines() # 按行转换为一维表
recordList = [row.split("\t") for row in rowList if row.strip()] # strip()函数删除空格、Tab等
self.dataSet = recordList
self.labels = labels
# 执行决策树函数
def train(self):
labels = copy.deepcopy(self.labels)
self.tree = self.buildTree(self.dataSet, labels)
# 构件决策树:穿件决策树主程序
def buildTree(self, dataSet, lables):
cateList = [data[-1] for data in dataSet] # 抽取源数据集中的决策标签列
# 程序终止条件1:如果classList只有一种决策标签,停止划分,返回这个决策标签
if cateList.count(cateList[0]) == len(cateList):
return cateList[0]
# 程序终止条件2:如果数据集的第一个决策标签只有一个,返回这个标签
if len(dataSet[0]) == 1:
return self.maxCate(cateList)
# 核心部分
bestFeat = self.getBestFeat(dataSet) # 返回数据集的最优特征轴
bestFeatLabel = lables[bestFeat]
tree = {bestFeatLabel: {}}
del (lables[bestFeat])
# 抽取最优特征轴的列向量
uniqueVals = set([data[bestFeat] for data in dataSet]) # 去重
for value in uniqueVals: # 决策树递归生长
subLables = lables[:] # 将删除后的特征类别集建立子类别集
# 按最优特征列和值分隔数据集
splitDataset = self.splitDataSet(dataSet, bestFeat, value)
subTree = self.buildTree(splitDataset, subLables) # 构建子树
tree[bestFeatLabel][value] = subTree
return tree
# 计算出现次数最多的类别标签
def maxCate(self, cateList):
items = dict([(cateList.count(i), i) for i in cateList])
return items[max(items.keys())]
# 计算最优特征
def getBestFeat(self, dataSet):
# 计算特征向量维,其中最后一列用于类别标签
numFeatures = len(dataSet[0]) - 1 # 特征向量维数=行向量维数-1
baseEntropy = self.computeEntropy(dataSet) # 基础熵
bestInfoGain = 0.0 # 初始化最优的信息增益
bestFeature = -1 # 初始化最优的特征轴
# 外循环:遍历数据集各列,计算最优特征轴
# i为数据集列索引:取值范围0~(numFeatures-1)
for i in xrange(numFeatures):
uniqueVals = set([data[i] for data in dataSet]) # 去重
newEntropy = 0.0
for value in uniqueVals:
subDataSet = self.splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * self.computeEntropy(subDataSet)
infoGain = baseEntropy - newEntropy
if (infoGain > bestInfoGain): # 信息增益大于0
bestInfoGain = infoGain # 用当前信息增益值替代之前的最优增益值
bestFeature = i # 重置最优特征为当前列
return bestFeature
# 计算信息熵
# @staticmethod
def computeEntropy(self, dataSet):
dataLen = float(len(dataSet))
cateList = [data[-1] for data in dataSet] # 从数据集中得到类别标签
# 得到类别为key、 出现次数value的字典
items = dict([(i, cateList.count(i)) for i in cateList])
infoEntropy = 0.0
for key in items: # 香农熵: = -p*log2(p) --infoEntropy = -prob * log(prob, 2)
prob = float(items[key]) / dataLen
infoEntropy -= prob * math.log(prob, 2)
return infoEntropy
# 划分数据集: 分割数据集; 删除特征轴所在的数据列,返回剩余的数据集
# dataSet : 数据集; axis: 特征轴; value: 特征轴的取值
def splitDataSet(self, dataSet, axis, value):
rtnList = []
for featVec in dataSet:
if featVec[axis] == value:
rFeatVec = featVec[:axis] # list操作:提取0~(axis-1)的元素
rFeatVec.extend(featVec[axis + 1:])
rtnList.append(rFeatVec)
return rtnList
# 存取树到文件
def storetree(self, inputTree, filename):
fw = open(filename,'w')
pickle.dump(inputTree, fw)
fw.close()
# 从文件抓取树
def grabTree(self, filename):
fr = open(filename)
return pickle.load(fr)
调用代码
# -*- coding: utf-8 -*-
from numpy import *
from ID3DTree import *
dtree = ID3DTree()
# ["age", "revenue", "student", "credit"]对应年龄、收入、学生、信誉4个特征
dtree.loadDataSet("dataset.dat", ["age", "revenue", "student", "credit"])
dtree.train()
dtree.storetree(dtree.tree, "data.tree")
mytree = dtree.grabTree("data.tree")
print mytree
来源:https://blog.csdn.net/yjIvan/article/details/71194383
0
投稿
猜你喜欢
- 一、LeetCode——125.验证回文串1.问题描述给定一个字符串,验证它是否是回文串,只考虑字母和数字字符,可以忽略字母的大小写。说明:
- 本文实例讲述了JavaScript求一组数的最小公倍数和最大公约数常用算法。分享给大家供大家参考,具体如下:方法来自求多个数最小公倍数的一种
- 生命游戏的算法就不多解释了,百度一下介绍随处可见。因为网上大多数版本都是基于pygame,matlab等外部库实现的,二维数组大多是用num
- 近来,打开微信群发消息,就会秒收到一些活跃分子的回复,有的时候感觉对方回答很在理,但是有的时候发现对方的回答其实是驴唇不对马嘴,仔细深究发现
- 本文实例讲述了Python装饰器。分享给大家供大家参考。具体分析如下:这是在Python学习小组上介绍的内容,现学现卖、多练习是好的学习方式
- 迷宫生成1.随机PRIM思路:先让迷宫中全都是墙,不断从列表(最初只含有一个启始单元格)中选取一个单元格标记为通路,将其周围(上下左右)未访
- 这是群里一朋友问的问题,当时我说判断下 day 是否相邻即可,后来细想,发现完全不对。问题需求给定5个相同格式的日期,怎么判断是否是连续5天
- 本文实例为大家分享了Python threading模块对单个接口进行并发测试的具体代码,供大家参考,具体内容如下本文知识点通过在threa
- Python项目中很多时候会需要将时间在Datetime格式和TimeStamp格式之间转化,又或者你需要将UTC时间转化为本地时间,本文总
- 一、什么是super1.super也是一个类,是的。他不是一个方法也不是一个内置的关键字。class A: pas
- 写了网址规范化后,尚奇公司的柳先生建议再深入讨论一下301转向/重定向。下面就谈谈我所了解的301转向在搜索引擎优化方面的应用。什么是301
- 其实网上已经有许多python语言书写的串口,但大部分都是python2写的,没有找到一个合适的python编写的串口助手,只能自己来写一个
- Django默认Path转换器str:匹配任何非空字符串,但不含斜杠/,如果你没有专门指定转换器,那么这个是默认使用的;int:匹配0和正整
- 本方法是基于文本密度的方法,最初的想法来源于哈工大的《基于行块分布函数的通用网页正文抽取算法》,本文基于此进行一些小修改。约定:
- 安装顺序rpm -ivhmysql-community-common-5.7.18-1.el7.x86_64.rpmmysql-commun
- 前言对于我这种英语比较差的人来说,无论是敲代码还是看文档,那都是离不开翻译软件的,于是我想自己用python做一个翻译软件,花了一个小时,终
- 场景:按照github文档上启动一个flask的app,默认是用5000端口,如果5000端口被占用,启动失败。样例代码:from flas
- 本文实例讲述了Python实现删除文件中含指定内容的行。分享给大家供大家参考,具体如下:#!/bin/env pythonimport sh
- 本文实例讲述了Python设计模式之工厂模式。分享给大家供大家参考,具体如下:工厂模式是一个在软件开发中用来创建对象的设计模式。工厂模式包涵
- 对于单页应用,官方提供了vue-router进行路由跳转的处理,本篇主要也是基于其官方文档写作而成。安装基于传统,我更喜欢采用npm包的形式