原理
通过一系列提问的方式，根据每个问题的不同答案选择不同的分支，最终完成分类。
步骤分解
1. 计算数据集的香农熵，然后遍历每个特征，计算按该特征划分后的信息增益，选取信息增益最大的特征作为当前节点的划分特征；再对划分得到的各个子数据集递归执行同样的步骤，直到构建出完整的决策树。最后可以将构建好的决策树序列化保存到磁盘中。
def chooseBestFeatureToSplit(dataSet):  # choose the best feature to split on
    """Return the index of the feature whose split gives the largest information gain.

    :param dataSet: list of rows; every column except the last is a feature value,
        the last column is the class label.
    :return: index of the best feature to split on. Only returns -1 when the rows
        contain no feature columns at all.
    """
    numFeatures = len(dataSet[0]) - 1  # the last column is the class label
    baseEntropy = calcShannonEnt(dataSet)  # entropy of the unsplit data set
    # Start below any achievable gain so that a real feature index is always
    # chosen, even when every feature has zero gain. Starting at 0.0 used to
    # return -1 in that case, which callers then used as a column index.
    bestInfoGain = float('-inf')
    bestFeature = -1
    numRows = float(len(dataSet))
    for i in range(numFeatures):
        # distinct values taken by feature i
        uniqueVals = {example[i] for example in dataSet}
        newEntropy = 0.0
        # weighted entropy of the partitions produced by splitting on feature i
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / numRows  # weight of this partition
            newEntropy += prob * calcShannonEnt(subDataSet)
        # lower post-split entropy means a larger information gain
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:  # strict '>' keeps the earliest feature on ties
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
2. 递归构建决策树
def createTree(dataSet, labels):
    """Recursively build a decision tree from ``dataSet``.

    :param dataSet: rows of feature values with the class label in the last column.
    :param labels: feature names, one per feature column of ``dataSet``.
        The caller's list is NOT modified (earlier versions deleted the chosen
        feature's label from it, breaking later use with ``classify``).
    :return: either a class label (leaf) or a nested dict of the form
        ``{featureLabel: {featureValue: subtree, ...}}``.
    """
    classList = [example[-1] for example in dataSet]
    # All rows share one class: stop splitting, this is a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Every feature has been used but classes are still mixed: majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # index of the best feature
    if bestFeat < 0:
        # Defensive: no usable feature was reported; fall back to a majority
        # vote instead of indexing ``labels`` with -1.
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}  # the tree is stored as nested dicts
    # Labels for the sub-trees: a new list without the consumed feature, so
    # the caller's ``labels`` stays intact.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    uniqueVals = set(example[bestFeat] for example in dataSet)
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
3. 使用 Matplotlib 注解绘制树形图
import matplotlib.pyplot as plt
import trees
# Text-box and arrow formats used by plotNode
decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # box style and fill colour for decision (internal) nodes
leafNode = dict(boxstyle="round4", fc="0.8")  # box style and fill colour for leaf nodes
arrow_args = dict(arrowstyle="<-")  # arrow style passed as arrowprops when linking a node to its parent
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one node box with an arrow connecting it to its parent.

    :param nodeTxt: text shown inside the box.
    :param centerPt: (x, y) of the box, in axes-fraction coordinates.
    :param parentPt: (x, y) the arrow connects to, in axes-fraction coordinates.
    :param nodeType: bbox style dict (``decisionNode`` or ``leafNode``).
    """
    axes = createPlot.axl  # axes created by createPlot
    axes.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
def createPlot(inTree):
    """Render ``inTree`` in figure 1 and show it.

    Stores the axes on ``createPlot.axl`` and the layout state (total width,
    total depth, running x/y offsets) on attributes of ``plotTree``, which the
    drawing helpers read.
    """
    figure = plt.figure(1, facecolor='white')  # white background
    figure.clf()  # start from an empty canvas
    # single subplot (1 row, 1 column, first cell), no frame, no tick marks
    createPlot.axl = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    plotTree.totalW = float(trees.getNumLeafs(inTree))
    plotTree.totalD = float(trees.getTreeDepth(inTree))
    # start half a leaf-slot left of 0 so leaves end up centred in their slots
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), ' ')
    plt.show()
def plotMidText(cntrPt, parentPt, txtString):
    """Write ``txtString`` halfway along the edge between a child and its parent."""
    childX, childY = cntrPt[0], cntrPt[1]
    parentX, parentY = parentPt[0], parentPt[1]
    midX = (parentX - childX) / 2.0 + childX
    midY = (parentY - childY) / 2.0 + childY
    createPlot.axl.text(midX, midY, txtString)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw ``myTree`` below ``parentPt``.

    :param myTree: nested-dict tree ``{featureLabel: {value: subtree_or_leaf}}``.
    :param parentPt: (x, y) of the parent node, in axes-fraction coordinates.
    :param nodeTxt: label written on the edge from the parent to this node.

    Uses the layout state stored on ``plotTree`` by ``createPlot``
    (``totalW``, ``totalD``, ``xOff``, ``yOff``) and updates ``xOff``/``yOff``.
    """
    numLeafs = trees.getNumLeafs(myTree)  # width of this subtree, in leaves
    # (The unused depth computation was removed: it re-walked the whole
    # subtree at every node for nothing.)
    firstStr = list(myTree.keys())[0]  # feature label at this node
    # centre this node horizontally above its own leaves
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)  # label the incoming edge
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # move one level down for the children
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if type(child).__name__ == 'dict':
            plotTree(child, cntrPt, str(key))
        else:
            # leaves advance the running x offset by one slot each
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # restore the level before returning to the parent
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def retrieveTree(i):
    """Return hard-coded sample tree number ``i`` (0 or 1).

    A fresh dict is built on every call, so callers may mutate the result.
    """
    sampleTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no',
                          1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}},
                                           1: 'no'}}}},
    ]
    return sampleTrees[i]
完整代码（分为 trees.py 和 treePlotter.py 两个文件）
trees.py
from math import log
import operator
import treePlotter
def calcShannonEnt(dataSet):
    """Return the Shannon entropy (base 2) of the class labels in ``dataSet``.

    The class label is taken from the last column of each row; the result is
    ``-sum(p * log2(p))`` over the label frequencies. An empty data set gives 0.0.
    """
    total = len(dataSet)
    frequency = {}  # class label -> number of rows with that label
    for row in dataSet:
        label = row[-1]
        frequency[label] = frequency.get(label, 0) + 1
    entropy = 0.0
    for count in frequency.values():
        p = float(count) / total
        entropy -= p * log(p, 2)
    return entropy
def splitDataSet(dataSet, axis, value):
    """Return the rows whose column ``axis`` equals ``value``, with that column removed.

    :param dataSet: list of list rows to filter.
    :param axis: index of the feature column to split on.
    :param value: feature value a row must have to be kept.
    :return: a new list of new row lists; ``dataSet`` itself is not modified.
    """
    matching = (row for row in dataSet if row[axis] == value)
    reduced = []
    for row in matching:
        kept = row[:axis]  # slicing copies, so the input row is untouched
        kept.extend(row[axis + 1:])
        reduced.append(kept)
    return reduced
def chooseBestFeatureToSplit(dataSet):  # choose the best feature to split on
    """Return the index of the feature whose split gives the largest information gain.

    :param dataSet: list of rows; every column except the last is a feature value,
        the last column is the class label.
    :return: index of the best feature to split on. Only returns -1 when the rows
        contain no feature columns at all.
    """
    numFeatures = len(dataSet[0]) - 1  # the last column is the class label
    baseEntropy = calcShannonEnt(dataSet)  # entropy of the unsplit data set
    # Start below any achievable gain so that a real feature index is always
    # chosen, even when every feature has zero gain. Starting at 0.0 used to
    # return -1 in that case, which callers then used as a column index.
    bestInfoGain = float('-inf')
    bestFeature = -1
    numRows = float(len(dataSet))
    for i in range(numFeatures):
        # distinct values taken by feature i
        uniqueVals = {example[i] for example in dataSet}
        newEntropy = 0.0
        # weighted entropy of the partitions produced by splitting on feature i
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / numRows  # weight of this partition
            newEntropy += prob * calcShannonEnt(subDataSet)
        # lower post-split entropy means a larger information gain
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:  # strict '>' keeps the earliest feature on ties
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
def majorityCnt(classList):  # majority vote decides a leaf's class
    """Return the most frequent class label in ``classList``.

    Ties go to the label that appears first in ``classList``. Raises
    IndexError for an empty list.
    """
    tally = {}  # class label -> occurrence count, in first-seen order
    for label in classList:
        tally[label] = tally.get(label, 0) + 1
    # sorted() is stable even with reverse=True, so equal counts keep first-seen order
    ranked = sorted(tally, key=tally.get, reverse=True)
    return ranked[0]
def createTree(dataSet, labels):
    """Recursively build a decision tree from ``dataSet``.

    :param dataSet: rows of feature values with the class label in the last column.
    :param labels: feature names, one per feature column of ``dataSet``.
        The caller's list is NOT modified (earlier versions deleted the chosen
        feature's label from it, breaking later use with ``classify``).
    :return: either a class label (leaf) or a nested dict of the form
        ``{featureLabel: {featureValue: subtree, ...}}``.
    """
    classList = [example[-1] for example in dataSet]
    # All rows share one class: stop splitting, this is a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Every feature has been used but classes are still mixed: majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # index of the best feature
    if bestFeat < 0:
        # Defensive: no usable feature was reported; fall back to a majority
        # vote instead of indexing ``labels`` with -1.
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}  # the tree is stored as nested dicts
    # Labels for the sub-trees: a new list without the consumed feature, so
    # the caller's ``labels`` stays intact.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    uniqueVals = set(example[bestFeat] for example in dataSet)
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
def getNumLeafs(myTree):
    """Return the number of leaf nodes in the nested-dict tree ``myTree``.

    A child that is a plain ``dict`` is a subtree; anything else is a leaf.
    """
    root = list(myTree)[0]  # the single feature label at this level
    branches = myTree[root]
    return sum(getNumLeafs(child) if type(child).__name__ == 'dict' else 1
               for child in branches.values())
def getTreeDepth(myTree):
    """Return the number of decision levels in the nested-dict tree ``myTree``.

    A node whose children are all leaves has depth 1; a node with no children
    has depth 0.
    """
    root = list(myTree)[0]  # the single feature label at this level
    deepest = 0
    for child in myTree[root].values():
        if type(child).__name__ == 'dict':
            depth = 1 + getTreeDepth(child)
        else:
            depth = 1
        deepest = max(deepest, depth)
    return deepest
def createDataSet():  # build the toy data set
    """Return ``(dataSet, labels)`` for a small two-feature toy problem.

    Each row is ``[feature0, feature1, classLabel]``; ``labels`` names the two
    feature columns.
    """
    labels = ['no surfacing', 'flippers']
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    return dataSet, labels
def classify(inputTree, featLabels, testVec):  # classifier
    """Classify ``testVec`` by walking the decision tree ``inputTree``.

    :param inputTree: nested-dict tree ``{featureLabel: {featureValue: subtree_or_leaf}}``.
    :param featLabels: feature names, in the same order as the columns of ``testVec``.
    :param testVec: feature values of the sample to classify.
    :return: the class label stored at the reached leaf.
    :raises KeyError: if the tree has no branch for the sample's value of the
        feature tested at some node (previously an obscure UnboundLocalError).
    """
    firstStr = list(inputTree.keys())[0]  # feature tested at this node
    secondDict = inputTree[firstStr]
    # map the feature name to its column index in testVec
    featIndex = featLabels.index(firstStr)
    featValue = testVec[featIndex]
    for key, branch in secondDict.items():
        if featValue == key:
            if type(branch).__name__ == 'dict':
                # not a leaf yet: keep descending
                return classify(branch, featLabels, testVec)
            return branch  # leaf: the class label
    raise KeyError('no branch for value %r of feature %r' % (featValue, firstStr))
def storeTree(inputTree, filename):  # serialize the tree to disk
    """Pickle ``inputTree`` into ``filename``, overwriting any existing file.

    The file is closed even if pickling fails.
    """
    import pickle
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def grabTree(filename):  # load a serialized tree
    """Return the tree previously pickled into ``filename`` by ``storeTree``.

    The file is opened read-only and always closed.
    WARNING: unpickling can execute arbitrary code; only load files you trust.
    """
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
if __name__ == "__main__":
    myDat, labels = createDataSet()
    print(myDat)
    # A hard-coded sample tree is used here; createTree(myDat, labels) could
    # build one from the data instead.
    myTree = treePlotter.retrieveTree(0)
    print(myTree)
    for sampleVec in ([1, 0], [1, 1]):
        print(classify(myTree, labels, sampleVec))
    print("===========store tree============")
    storeTree(myTree, 'classifierStorafe.txt')
    print(grabTree('classifierStorafe.txt'))
treePlotter.py
import matplotlib.pyplot as plt
import trees
# Text-box and arrow formats used by plotNode
decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # box style and fill colour for decision (internal) nodes
leafNode = dict(boxstyle="round4", fc="0.8")  # box style and fill colour for leaf nodes
arrow_args = dict(arrowstyle="<-")  # arrow style passed as arrowprops when linking a node to its parent
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one node box with an arrow connecting it to its parent.

    :param nodeTxt: text shown inside the box.
    :param centerPt: (x, y) of the box, in axes-fraction coordinates.
    :param parentPt: (x, y) the arrow connects to, in axes-fraction coordinates.
    :param nodeType: bbox style dict (``decisionNode`` or ``leafNode``).
    """
    axes = createPlot.axl  # axes created by createPlot
    axes.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
def createPlot(inTree):
    """Render ``inTree`` in figure 1 and show it.

    Stores the axes on ``createPlot.axl`` and the layout state (total width,
    total depth, running x/y offsets) on attributes of ``plotTree``, which the
    drawing helpers read.
    """
    figure = plt.figure(1, facecolor='white')  # white background
    figure.clf()  # start from an empty canvas
    # single subplot (1 row, 1 column, first cell), no frame, no tick marks
    createPlot.axl = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    plotTree.totalW = float(trees.getNumLeafs(inTree))
    plotTree.totalD = float(trees.getTreeDepth(inTree))
    # start half a leaf-slot left of 0 so leaves end up centred in their slots
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), ' ')
    plt.show()
def plotMidText(cntrPt, parentPt, txtString):
    """Write ``txtString`` halfway along the edge between a child and its parent."""
    childX, childY = cntrPt[0], cntrPt[1]
    parentX, parentY = parentPt[0], parentPt[1]
    midX = (parentX - childX) / 2.0 + childX
    midY = (parentY - childY) / 2.0 + childY
    createPlot.axl.text(midX, midY, txtString)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw ``myTree`` below ``parentPt``.

    :param myTree: nested-dict tree ``{featureLabel: {value: subtree_or_leaf}}``.
    :param parentPt: (x, y) of the parent node, in axes-fraction coordinates.
    :param nodeTxt: label written on the edge from the parent to this node.

    Uses the layout state stored on ``plotTree`` by ``createPlot``
    (``totalW``, ``totalD``, ``xOff``, ``yOff``) and updates ``xOff``/``yOff``.
    """
    numLeafs = trees.getNumLeafs(myTree)  # width of this subtree, in leaves
    # (The unused depth computation was removed: it re-walked the whole
    # subtree at every node for nothing.)
    firstStr = list(myTree.keys())[0]  # feature label at this node
    # centre this node horizontally above its own leaves
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)  # label the incoming edge
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # move one level down for the children
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if type(child).__name__ == 'dict':
            plotTree(child, cntrPt, str(key))
        else:
            # leaves advance the running x offset by one slot each
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # restore the level before returning to the parent
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def retrieveTree(i):
    """Return hard-coded sample tree number ``i`` (0 or 1).

    A fresh dict is built on every call, so callers may mutate the result.
    """
    sampleTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no',
                          1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}},
                                           1: 'no'}}}},
    ]
    return sampleTrees[i]
if __name__ == "__main__":
    # Demo: draw sample tree 0 after adding a third branch (value 3) under the root.
    demoTree = retrieveTree(0)
    demoTree['no surfacing'][3] = 'maybe'
    createPlot(demoTree)