在上一篇文章基于信息增益的ID3决策树介绍。中介绍了基本的决策树概念和基于信息增益的ID3决策树的计算。这篇文章中介绍一下如何使用Python实现一个ID3决策树,其中主要的代码来自于机器学习实战一书中,本人对其做了一些改动,增加了一些内容。
决策树的伪代码。
决策树的生成可以使用一个递归来实现,在西瓜书中给出了决策树的伪代码:
输入:训练集$ D={(x_1,y_1),(x_2,y_2),…,(x_m,y_m)}$;
属性集 .
过程:函数TreeGenerate(D,A)
1:生成节点node;
2:if D 中样本全属于同一类别 C then
3: 将node标记为C类叶节点;return
4:end if
5:if A= ∅ OR D中样本在A上取值相同 then
6: 将node标记为叶节点,其类别标记为D中样本数最多的类;return
7:end if
8:从A中选择最优划分属性 ;
9:for 的每一个值 do
10: 为node生成一个分支;令 表示 中在 上取值为 的样本子集;
11: if 为空 then
12: 将分支节点标记为叶节点,其类别标记为 中样本最多的类;return
13:else
14: 以TreeGenerate($D_v,A $ \ )为分支节点。
15: end if
16:end for
输出:以node为根节点的一颗决策树
伪代码中结束递归的三个条件:
-
第2行:此时样本D中的样本全部都属于一种类别,比如都是好瓜,那么此时就说明不需要再划分了。
-
第5行:如果此时属性集合为空或者此时所有的样本的各个属性值都相同,比如剩了三个西瓜,这三个西瓜的根蒂、色泽、敲声都是一样的,这时候无法再根据属性进行划分了,所以在这些剩下的西瓜中找出数目最多的类别。
-
第12:如果数据集在某一个属性上没有样本,比如在经过多次划分,剩下的西瓜的色泽已经没有浅白这种瓜了,我们就让此刻浅白这种瓜的类别等于当前节点的父节点中样本数目最多的类别。
其中第14行的 \ ,表示集合减法,也就说在集合A中除去集合 的内容。
实现决策树。
-
首先我们需要获取到样本集:createDataSet。
关于样本集的创建,本人直接使用笨办法将所有的数据手动输入,封装在一个函数中,可以同时获取到数据和对应的类别,以及全部属性的全部可能性。
def createDataSet(): """ 创建测试的数据集 :return: """ dataSet = [ # 1 ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'], # 2 ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'], # 3 ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'], # 4 ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'], # 5 ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'], # 6 ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '好瓜'], # 7 ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '好瓜'], # 8 ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '好瓜'], # ---------------------------------------------------- # 9 ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'], # 10 ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '坏瓜'], # 11 ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '坏瓜'], # 12 ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '坏瓜'], # 13 ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '坏瓜'], # 14 ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '坏瓜'], # 15 ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '坏瓜'], # 16 ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '坏瓜'], # 17 ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'] ] # 特征值列表 labels = ['色泽', '根蒂', '敲击', '纹理', '脐部', '触感'] # 特征对应的所有可能的情况 labels_full = {} for i in range(len(labels)): labelList = [example[i] for example in dataSet] uniqueLabel = set(labelList) labels_full[labels[i]] = uniqueLabel return dataSet, labels, labels_full
其中可能对labels_full有一点疑惑,这个列表的输出为:
{'触感': {'软粘', '硬滑'}, '纹理': {'模糊', '稍糊', '清晰'}, '根蒂': {'蜷缩', '硬挺', '稍蜷'}, '色泽': {'浅白', '乌黑', '青绿'}, '敲击': {'沉闷', '清脆', '浊响'}, '脐部': {'稍凹', '凹陷', '平坦'}}
这个内容在最后对决策树进行补全的时候会使用到。
-
其次我们需要一个能够对数据集根据属性值进行划分的函数:splitDataSet。
我们如果选取了某个特征属性作为当前的划分属性,那么我们需要将该特征属性上的不同的值划分到不同的分支中,比如色泽这个属性,一共有三个属性值:乌黑、浅白、青绿,那么我们就会产生三个分支,每个分支对应一个属性值的样本集合,并且分支中还减去了该属性。
如图中,样本集A中所有的西瓜的色泽都为青绿,B中的色泽都为浅白,C中的色泽都为乌黑。并且由于已经使用过了色泽这个属性,所以在这三个分支中能够使用的属性需要去掉色泽。
def splitDataSet(dataSet, axis, value): """ 按照给定的特征值,将数据集划分 :param dataSet: 数据集 :param axis: 给定特征值的坐标 :param value: 给定特征值满足的条件,只有给定特征值等于这个value的时候才会返回 :return: """ # 创建一个新的列表,防止对原来的列表进行修改 retDataSet = [] # 遍历整个数据集 for featVec in dataSet: # 如果给定特征值等于想要的特征值 if featVec[axis] == value: # 将该特征值前面的内容保存起来 reducedFeatVec = featVec[:axis] # 将该特征值后面的内容保存起来,所以将给定特征值给去掉了 reducedFeatVec.extend(featVec[axis + 1:]) # 添加到返回列表中 retDataSet.append(reducedFeatVec) return retDataSet
比如,我们想要得到色泽这个属性上等于浅白的样本,那么就可以使用如下代码:
retDate = splitDataSet(myDat, 0, '浅白') # 在属性集中,下标0是色泽
可以得到色泽等于浅白的所有样本,并且除去了色泽属性:
[['蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'], ['硬挺', '清脆', '模糊', '平坦', '硬滑', '坏瓜'], ['蜷缩', '浊响', '模糊', '平坦', '软粘', '坏瓜'], ['稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '坏瓜'], ['蜷缩', '浊响', '模糊', '平坦', '硬滑', '坏瓜']]
-
我们还需要一个能够统计样本集中样本数据最多的那个类别:majorityCnt。
还记得伪代码中的第二个结束条件么,也就是伪代码中的第5行。当划分到一定程度的时候,可能会出现剩下的样本集中所有的属性取值都相同的情况,或者此刻已经没有属性可以进行划分了,那么此时就需要进行分类了,分类的依据就是找出此刻样本集中数目最多的那个类别,作为该节点的类别。
比如:在西瓜样本集中一共有6个特征属性,我们根据这6个特征属性对样本集进行划分,不断的划分直到最后所有的特征属性都用完了,但是此刻样本的类别并不是全都一样的,那么我们就执行”少数服从多数“的办法,将多数类别设置为此时类别,如下图最后的样本集设置为好瓜。
再比如:样本集中的所有属性的值都一样,同样也是计算出数目最多的那个类别,假如我们原始的样本集只有三个西瓜,这三个瓜的根蒂、敲声等都一样,如下图也将其设置为好瓜:
def majorityCnt(classList): """ 找到次数最多的类别标签 :param classList: :return: """ # 用来统计标签的票数 classCount = collections.defaultdict(int) # 遍历所有的标签类别 for vote in classList: classCount[vote] += 1 # 从大到小排序 sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) # 返回次数最多的标签 return sortedClassCount[0][0]
比如:
print(majorityCnt(['好瓜', '好瓜', '坏瓜'])) # 得到的就是好瓜
-
计算信息熵:calcShannonEnt。
伪代码中的第8行需要从属性集中选取最优化分属性,这个时候就需要使用到信息增益了,但是再计算信息增益之前我们还需要一个函数来计算信息熵,信息熵的计算按照公式来就可以。
def calcShannonEnt(dataSet): """ 计算给定数据集的信息熵(香农熵) :param dataSet: :return: """ # 计算出数据集的总数 numEntries = len(dataSet) # 用来统计标签 labelCounts = collections.defaultdict(int) # 循环整个数据集,得到数据的分类标签 for featVec in dataSet: # 得到当前的标签 currentLabel = featVec[-1] # # 如果当前的标签不再标签集中,就添加进去(书中的写法) # if currentLabel not in labelCounts.keys(): # labelCounts[currentLabel] = 0 # # # 标签集中的对应标签数目加一 # labelCounts[currentLabel] += 1 # 也可以写成如下 labelCounts[currentLabel] += 1 # 默认的信息熵 shannonEnt = 0.0 for key in labelCounts: # 计算出当前分类标签占总标签的比例数 prob = float(labelCounts[key]) / numEntries # 以2为底求对数 shannonEnt -= prob * log(prob, 2) return shannonEnt
比如使用该函数计算原始数据的信息熵,与我们之前计算得到的一致,需要保留小数即可:
myDat, labels, labels_full = createDataSet() # 得到数据 print(calcShannonEnt(myDat)) # 输出为:0.9975025463691153
-
计算剩下的各个属性的信息增益,选取信息增益最大的那个作为划分点:chooseBestFeatureToSplit。
决策树最重要的部分,计算各个属性的信息增益,最后返回信息增益最大的那个属性下标:
def chooseBestFeatureToSplit(dataSet, labels): """ 选择最好的数据集划分特征,根据信息增益值来计算 :param dataSet: :return: """ # 得到数据的特征值总数 numFeatures = len(dataSet[0]) - 1 # 计算出基础信息熵 baseEntropy = calcShannonEnt(dataSet) # 基础信息增益为0.0 bestInfoGain = 0.0 # 最好的特征值 bestFeature = -1 # 对每个特征值进行求信息熵 for i in range(numFeatures): # 得到数据集中所有的当前特征值列表 featList = [example[i] for example in dataSet] # 将当前特征唯一化,也就是说当前特征值中共有多少种 uniqueVals = set(featList) # 新的熵,代表当前特征值的熵 newEntropy = 0.0 # 遍历现在有的特征的可能性 for value in uniqueVals: # 在全部数据集的当前特征位置上,找到该特征值等于当前值的集合 subDataSet = splitDataSet(dataSet=dataSet, axis=i, value=value) # 计算出权重 prob = len(subDataSet) / float(len(dataSet)) # 计算出当前特征值的熵 newEntropy += prob * calcShannonEnt(subDataSet) # 计算出“信息增益” infoGain = baseEntropy - newEntropy # print('当前特征值为:' + labels[i] + ',对应的信息增益值为:' + str(infoGain)) # 如果当前的信息增益比原来的大 if infoGain > bestInfoGain: # 最好的信息增益 bestInfoGain = infoGain # 新的最好的用来划分的特征值 bestFeature = i # print('信息增益最大的特征为:' + labels[bestFeature]) return bestFeature
如果使用该函数对原始样本集进行计算,如果将该函数中的打印注释去掉的话会得到如下数据:
myDat, labels, labels_full = createDataSet() print(chooseBestFeatureToSplit(myDat, labels)) 输出如下: 当前特征值为:色泽,对应的信息增益值为:0.10812516526536531
当前特征值为:根蒂,对应的信息增益值为:0.14267495956679288
当前特征值为:敲击,对应的信息增益值为:0.14078143361499584
当前特征值为:纹理,对应的信息增益值为:0.3805918973682686
当前特征值为:脐部,对应的信息增益值为:0.28915878284167895
当前特征值为:触感,对应的信息增益值为:0.006046489176565584
信息增益最大的特征为:纹理
3
```
***
6. 判断样本集的各个属性是否完全一致:judgeEqualLabels。
在伪代码中,结束递归的条件其中有一项是:样本集在属性集合$A$中的各个取值完全一样,此时也是没有必要继续划分的,可以作为递归结束的条件,该函数会返回True或者False,如果是True说明数据集的所有属性的取值都一样,False则不然。
```
def judgeEqualLabels(dataSet):
"""
判断数据集的各个属性集是否完全一致
:param dataSet:
:return:
"""
# 计算出样本集中共有多少个属性,最后一个为类别
feature_leng = len(dataSet[0]) - 1
# 计算出共有多少个数据
data_leng = len(dataSet)
# 标记每个属性中第一个属性值是什么
first_feature = ''
# 各个属性集是否完全一致
is_equal = True
# 遍历全部属性
for i in range(feature_leng):
# 得到第一个样本的第i个属性
first_feature = dataSet[0][i]
# 与样本集中所有的数据进行对比,看看在该属性上是否都一致
for _ in range(1, data_leng):
# 如果发现不相等的,则直接返回False
if first_feature != dataSet[_][i]:
return False
return is_equal
```
-
使用上述的各个函数,创建决策树:createTree。
按照伪代码的思路,将上述各个函数拼装起来,这里的结束条件只有两种:第一个 if 对应的就是伪代码中的数据集 都属于一个类别的情况,第二个 if 对应的是伪代码中属性集 为空或者各个属性取值都一样的情况。此外还有一种情况并没有在这里给出,那就是如果当前属性的某个属性值上的样本集为空,应该是将其设置为父节点的所含样本最多的类别,但是这个在递归里面实现有一点难度,因为需要考虑到各个属性的各个值的情况,以及父节点类别的计算,比较复杂,所以在创建决策树的时候并不实现该步骤。
def createTree(dataSet, labels): """ 创建决策树 :param dataSet: 数据集 :param labels: 特征标签 :return: """ # 拿到所有数据集的分类标签 classList = [example[-1] for example in dataSet] # 统计第一个标签出现的次数,与总标签个数比较,如果相等则说明当前列表中全部都是一种标签,此时停止划分 if classList.count(classList[0]) == len(classList): return classList[0] # 计算第一行有多少个数据,如果只有一个的话说明所有的特征属性都遍历完了,剩下的一个就是类别标签,或者所有的样本在全部属性上都一致 if len(dataSet[0]) == 1 or judgeEqualLabels(dataSet): # 返回剩下标签中出现次数较多的那个 return majorityCnt(classList) # 选择最好的划分特征,得到该特征的下标 bestFeat = chooseBestFeatureToSplit(dataSet=dataSet, labels=labels) # 得到最好特征的名称 bestFeatLabel = labels[bestFeat] # 使用一个字典来存储树结构,分叉处为划分的特征名称 myTree = {bestFeatLabel: {}} # 将本次划分的特征值从列表中删除掉 del(labels[bestFeat]) # 得到当前特征标签的所有可能值 featValues = [example[bestFeat] for example in dataSet] # 唯一化,去掉重复的特征值 uniqueVals = set(featValues) # 遍历所有的特征值 for value in uniqueVals: # 得到剩下的特征标签 subLabels = labels[:] subTree = createTree(splitDataSet(dataSet=dataSet, axis=bestFeat, value=value), subLabels) # 递归调用,将数据集中该特征等于当前特征值的所有数据划分到当前节点下,递归调用时需要先将当前的特征去除掉 myTree[bestFeatLabel][value] = subTree return myTree
如果使用该函数生成决策树,会得到如下一个字典:
{'纹理': {'稍糊': {'触感': {'硬滑': '坏瓜', '软粘': '好瓜'}}, '模糊': '坏瓜', '清晰': {'根蒂': {'蜷缩': '好瓜', '稍蜷': {'色泽': {'乌黑': {'触感': {'硬滑': '好瓜', '软粘': '坏瓜'}}, '青绿': '好瓜'}}, '硬挺': '坏瓜'}}}}
-
决策树绘制:treePlotter.createPlot(myTree)。
得到了决策树就需要将其绘制出来,有关决策树的绘制,直接使用了《机器学习实战》中的代码:
# @Time : 2017/12/18 19:46
# @Author : Leafage
# @File : treePlotter.py
# @Software: PyCharm
import matplotlib.pylab as plt
import matplotlib
# 能够显示中文
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['font.serif'] = ['SimHei']
# 分叉节点,也就是决策节点
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
# 叶子节点
leafNode = dict(boxstyle="round4", fc="0.8")
# 箭头样式
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
"""
绘制一个节点
:param nodeTxt: 描述该节点的文本信息
:param centerPt: 文本的坐标
:param parentPt: 点的坐标,这里也是指父节点的坐标
:param nodeType: 节点类型,分为叶子节点和决策节点
:return:
"""
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def getNumLeafs(myTree):
"""
获取叶节点的数目
:param myTree:
:return:
"""
# 统计叶子节点的总数
numLeafs = 0
# 得到当前第一个key,也就是根节点
firstStr = list(myTree.keys())[0]
# 得到第一个key对应的内容
secondDict = myTree[firstStr]
# 递归遍历叶子节点
for key in secondDict.keys():
# 如果key对应的是一个字典,就递归调用
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
# 不是的话,说明此时是一个叶子节点
else:
numLeafs += 1
return numLeafs
def getTreeDepth(myTree):
"""
得到数的深度层数
:param myTree:
:return:
"""
# 用来保存最大层数
maxDepth = 0
# 得到根节点
firstStr = list(myTree.keys())[0]
# 得到key对应的内容
secondDic = myTree[firstStr]
# 遍历所有子节点
for key in secondDic.keys():
# 如果该节点是字典,就递归调用
if type(secondDic[key]).__name__ == 'dict':
# 子节点的深度加1
thisDepth = 1 + getTreeDepth(secondDic[key])
# 说明此时是叶子节点
else:
thisDepth = 1
# 替换最大层数
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
def plotMidText(cntrPt, parentPt, txtString):
"""
计算出父节点和子节点的中间位置,填充信息
:param cntrPt: 子节点坐标
:param parentPt: 父节点坐标
:param txtString: 填充的文本信息
:return:
"""
# 计算x轴的中间位置
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
# 计算y轴的中间位置
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
# 进行绘制
createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
"""
绘制出树的所有节点,递归绘制
:param myTree: 树
:param parentPt: 父节点的坐标
:param nodeTxt: 节点的文本信息
:return:
"""
# 计算叶子节点数
numLeafs = getNumLeafs(myTree=myTree)
# 计算树的深度
depth = getTreeDepth(myTree=myTree)
# 得到根节点的信息内容
firstStr = list(myTree.keys())[0]
# 计算出当前根节点在所有子节点的中间坐标,也就是当前x轴的偏移量加上计算出来的根节点的中心位置作为x轴(比如说第一次:初始的x偏移量为:-1/2W,计算出来的根节点中心位置为:(1+W)/2W,相加得到:1/2),当前y轴偏移量作为y轴
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
# 绘制该节点与父节点的联系
plotMidText(cntrPt, parentPt, nodeTxt)
# 绘制该节点
plotNode(firstStr, cntrPt, parentPt, decisionNode)
# 得到当前根节点对应的子树
secondDict = myTree[firstStr]
# 计算出新的y轴偏移量,向下移动1/D,也就是下一层的绘制y轴
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
# 循环遍历所有的key
for key in secondDict.keys():
# 如果当前的key是字典的话,代表还有子树,则递归遍历
if isinstance(secondDict[key], dict):
plotTree(secondDict[key], cntrPt, str(key))
else:
# 计算新的x轴偏移量,也就是下个叶子绘制的x轴坐标向右移动了1/W
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
# 打开注释可以观察叶子节点的坐标变化
# print((plotTree.xOff, plotTree.yOff), secondDict[key])
# 绘制叶子节点
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
# 绘制叶子节点和父节点的中间连线内容
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
# 返回递归之前,需要将y轴的偏移量增加,向上移动1/D,也就是返回去绘制上一层的y轴
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def createPlot(inTree):
"""
需要绘制的决策树
:param inTree: 决策树字典
:return:
"""
# 创建一个图像
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
# 计算出决策树的总宽度
plotTree.totalW = float(getNumLeafs(inTree))
# 计算出决策树的总深度
plotTree.totalD = float(getTreeDepth(inTree))
# 初始的x轴偏移量,也就是-1/2W,每次向右移动1/W,也就是第一个叶子节点绘制的x坐标为:1/2W,第二个:3/2W,第三个:5/2W,最后一个:(W-1)/2W
plotTree.xOff = -0.5/plotTree.totalW
# 初始的y轴偏移量,每次向下或者向上移动1/D
plotTree.yOff = 1.0
# 调用函数进行绘制节点图像
plotTree(inTree, (0.5, 1.0), '')
# 绘制
plt.show()
if __name__ == '__main__':
testTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}}
createPlot(testTree)
此刻绘制的决策树如下:
但是如图中箭头所指,在色泽的分支下并没有浅白这个选项,这是为什么呢?因为在给出的样本集中并不存在纹理为清晰,根蒂为稍蜷、色泽为浅白的瓜,这就属于伪代码中递归结束的第3中情况,也就是我们在创建决策树中并没有实现的那种结束递归的判断。此时我们是应该将色泽属性下浅白这个分支设置为其父节点中所含样本最多的类别的,也就是色泽中数目最多的类别,如果父节点中分类比例一致的话,那么就随机取其中一个。这部分的实现就需要使用到下面的补全决策树了。
8 . 补全决策树:makeTreeFull。
我们已经知道使用上面递归生成的决策树,可能会出现某个特征属性下为空集的情况,而此时并没有按照规定将其设置为父节点所含数目最多的类别,而是直接不计算该节点,造成了决策树的缺失,就如上面的色泽没有浅白一样,所以我们需要将其补全。
补全的思路就是,在创建数据集的时候就得到各个属性对应的全部属性值,也就是一开始创建数据集返回的labels_full,然后再我们创建完决策树之后,遍历一遍该决策树。计算出每一层所含数目最多的类别,并且如果发现当前层次出现了缺失的属性值就添加进去作为子节点,并且使用其父节点最多的类别作为其分类。
def makeTreeFull(myTree, labels_full, default):
"""
将树中的不存在的特征标签进行补全,补全为父节点中出现最多的类别
:param myTree: 生成的树
:param labels_full: 特征的全部标签
:param parentClass: 父节点中所含最多的类别
:param default: 如果缺失标签中父节点无法判断类别则使用该值
:return:
"""
# 这里所说的父节点就是当前根节点,把当前根节点下不存在的特征标签作为子节点
# 拿到当前的根节点
root_key = list(myTree.keys())[0]
# 拿到根节点下的所有分类,可能是子节点(好瓜or坏瓜)也可能不是子节点(再次划分的属性值)
sub_tree = myTree[root_key]
# 如果是叶子节点就结束
if isinstance(sub_tree, str):
return
# 找到使用当前节点分类下最多的种类,该分类结果作为新特征标签的分类,如:色泽下面没有浅白则用色泽中有的青绿分类作为浅白的分类
root_class = []
# 把已经分好类的结果记录下来
for sub_key in sub_tree.keys():
if isinstance(sub_tree[sub_key], str):
root_class.append(sub_tree[sub_key])
# 找到本层出现最多的类别,可能会出现相同的情况取其一
if len(root_class):
most_class = collections.Counter(root_class).most_common(1)[0][0]
else:
most_class = None# 当前节点下没有已经分类好的属性
# print(most_class)
# 循环遍历全部特征标签,将不存在标签添加进去
for label in labels_full[root_key]:
if label not in sub_tree.keys():
if most_class is not None:
sub_tree[label] = most_class
else:
sub_tree[label] = default
# 递归处理
for sub_key in sub_tree.keys():
if isinstance(sub_tree[sub_key], dict):
makeTreeFull(myTree=sub_tree[sub_key], labels_full=labels_full, default=default)
补全决策树之后再次绘图:
此时色泽中已经有浅白这个分支了,因为父节点色泽中只有一个青绿分支为好瓜,所以将其设置为好瓜。
此时与西瓜书中的决策树一致:
注意:书中的触感分支是错误的,在勘误中已经改正,改正之后就如我们绘制的一样,只不过节点的位置不一样罢了:
源代码:西瓜书决策树实现