# Implement a decision tree algorithm (ID3)
# 2019-6-13
# by llllzy
from math import log
import operator

import trees  # self-written module that provides createDataSet()

myDat, labels = trees.createDataSet()
dataSet = myDat
# 3-1 Compute the Shannon entropy of a data set
def calShannonEnt(dataSet):
    # Number of rows, plus an empty dict for the label counts
    row = len(dataSet)
    labelCounts = {}
    shannonEnt = 0.0
    # Count the frequency of each class label
    for line in dataSet:                        # read row by row
        label = line[-1]                        # the last column of each row is the class label
        labelCounts[label] = labelCounts.get(label, 0) + 1  # count this label
    # Turn counts into probabilities and accumulate the entropy
    for key in labelCounts:
        prob = float(labelCounts[key]) / row
        shannonEnt -= prob * log(prob, 2)       # entropy: H = -sum(p(xi) * log2 p(xi))
    return shannonEnt
calShannonEnt(myDat)
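# Sanity check for the sample data (hand-computed, not part of the original script):
# with 2 'yes' and 3 'no' labels, H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.971,
# so calShannonEnt(myDat) should return roughly 0.9710.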
# 3-2 Split the data set on a given feature
# Goal: collect every row whose value in the chosen feature column equals `value`,
# with that feature column removed
def splitDataSet(dataSet, axis, value):  # axis: index of the feature column, value: feature value to keep
    retDataSet = []
    for line in dataSet:
        if line[axis] == value:
            reducedline = line[:axis]            # everything before the feature column
            reducedline.extend(line[axis+1:])    # everything after the feature column
            retDataSet.append(reducedline)
    return retDataSet
splitDataSet(myDat, 0, 1)
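# For the sample data, splitDataSet(myDat, 0, 1) should return
# [[1, 'yes'], [1, 'yes'], [0, 'no']]: the rows where feature 0 equals 1, with column 0 removed
# (traced by hand, not output captured from the original script).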
# 3-3 Choose the best feature to split on (largest information gain)
def chooseBestFeatureToSplit(dataSet):
    # Number of features (the last column is the class label)
    numFeatures = len(dataSet[0]) - 1
    # Entropy before any split
    baseEntropy = calShannonEnt(dataSet)
    # Track the best information gain found so far
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)              # unique values of feature i
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)    # split on feature i == value
            prob = len(subDataSet) / float(len(dataSet))    # fraction of rows in this branch
            newEntropy += prob * calShannonEnt(subDataSet)  # weighted entropy after the split
        infoGain = baseEntropy - newEntropy
        # Keep the feature with the largest information gain
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
chooseBestFeatureToSplit(myDat)
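# Hand-computed check for the sample data (not from the original script):
# base entropy H ≈ 0.971.
# Splitting on feature 0: the value-1 branch has 2 'yes' / 1 'no' (H ≈ 0.918), the value-0 branch is pure,
#   so the weighted entropy is 3/5 * 0.918 ≈ 0.551 and the gain is ≈ 0.420.
# Splitting on feature 1: the value-1 branch has 2 'yes' / 2 'no' (H = 1.0), the value-0 branch is pure,
#   so the weighted entropy is 4/5 * 1.0 = 0.8 and the gain is ≈ 0.171.
# Feature 0 wins, so chooseBestFeatureToSplit(myDat) should return 0.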
# 3-4 Recursively build the decision tree
# Majority vote helper: when all features are used up but the class labels are still
# mixed, return the most frequent class
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
    # Extract the class label column
    classList = [example[-1] for example in dataSet]
    # Stop splitting when all classes are identical
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # All features have been used: return the most frequent class
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)     # pick the best feature
    bestFeatLabel = labels[bestFeat]                 # name of the best feature
    myTree = {bestFeatLabel: {}}                     # nested dict representing the tree
    del(labels[bestFeat])                            # remove the used feature name
    featValues = [example[bestFeat] for example in dataSet]  # values of the best feature
    uniqueValues = set(featValues)                   # unique values of the best feature
    for value in uniqueValues:
        subLabels = labels[:]                        # copy so recursion does not clobber sibling branches
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
myTree = createTree(myDat, labels)
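# For the sample data, the call above should build (traced by hand, not captured output):
#     {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
# Note that createTree deletes entries from `labels` in place; pass labels[:] instead
# if the original label list is still needed afterwards.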
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
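# A minimal usage sketch (hypothetical, not part of the original script): the local
# createDataSet() returns the same (dataSet, labels) pair shape as trees.createDataSet(),
# so the top of the file could also read
#     myDat, labels = createDataSet()
# and majorityCnt(['yes', 'no', 'no']) would return 'no', the most frequent label.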