1. The Decision Tree Model
1.1 Definition
A classification decision tree model is a tree structure that describes how instances are classified. A decision tree consists of nodes and directed edges. Nodes come in two kinds: internal nodes, each representing a feature or attribute, and leaf nodes, each representing a class.
The figure below shows a decision tree model, with circles for internal nodes and boxes for leaf nodes.
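In the implementation in section 3.1, such a tree is stored as nested dicts: each internal node is a dict with a single key (the feature name) whose value maps each feature value to either a class label (a leaf) or another dict (a subtree). For example, the tree learned from the sample data set later in this post is:

    {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}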
1.2 Decision Tree Learning
Decision tree learning typically proceeds in three steps, covered in the sections that follow: feature selection, tree generation, and tree pruning.
2. Feature Selection
Feature selection means picking out the features that actually have discriminative power for the training data, which makes decision tree learning more efficient. The usual selection criteria are information gain and the information gain ratio.
2.1 Definition of Entropy
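As a reference for the code below (calEntropy in section 3.1 uses the base-2 logarithm), the standard definition: for a discrete random variable X taking n values x_i with probabilities p_i = P(X = x_i), the entropy is

    H(X) = -\sum_{i=1}^{n} p_i \log_2 p_i

The larger the entropy, the greater the uncertainty of X. For the sample data set in section 3.1, whose class labels are {yes, yes, no, no, no}, this gives H = -(2/5)\log_2(2/5) - (3/5)\log_2(3/5) \approx 0.971, matching the test output 0.9709505944546686.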
2.2 Conditional Entropy
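The conditional entropy H(Y|X) is the entropy of Y averaged over the values of X:

    H(Y \mid X) = \sum_{i=1}^{n} p_i \, H(Y \mid X = x_i), \qquad p_i = P(X = x_i)

This is exactly what chooseBestFeatureToSplit below computes as newEntropy: the entropy of each feature-value subset, weighted by that subset's relative frequency.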
2.3 Information Gain
Information gain measures by how much knowing the value of a feature X reduces the uncertainty about the class Y.
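In symbols (as the comment in chooseBestFeatureToSplit puts it), the information gain of feature A on data set D is

    g(D, A) = H(D) - H(D \mid A)

A worked check on the sample data set (H(D) \approx 0.971 from above): splitting on feature 0 ('no surfacing') gives g \approx 0.971 - (3/5)(0.918) \approx 0.420, while splitting on feature 1 ('flippers') gives g \approx 0.971 - (4/5)(1.0) \approx 0.171, so feature 0 wins, consistent with chooseBestFeatureToSplit returning 0 in the test code.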
2.4 Information Gain Ratio
The magnitude of an information gain is relative to the training set: when the empirical entropy of the training set is large, information gain values tend to be large, and small otherwise. Put differently, selecting the splitting feature by information gain is biased toward features with many distinct values; the information gain ratio corrects this problem.
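Concretely, using the standard definition, the gain ratio divides the gain by the entropy of D with respect to the values of feature A:

    g_R(D, A) = \frac{g(D, A)}{H_A(D)}, \qquad H_A(D) = -\sum_{i=1}^{n} \frac{|D_i|}{|D|} \log_2 \frac{|D_i|}{|D|}

where D_i is the subset of D on which A takes its i-th value. H_A(D) grows with the number of distinct values of A, so many-valued features are penalized.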
3. The ID3 Algorithm (Iterative Dichotomiser 3)
ID3 grows the tree recursively: at each node it computes the information gain of every remaining feature, splits on the feature with the largest gain, and stops when all samples at a node share one class or no features remain.
3.1 Code Implementation
# -*- coding: utf-8 -*-
"""
Created on Fri Apr 13 18:50:19 2018
file name: tree.py
@author: lizihua
"""
from math import log
import operator

# Build a sample data set
def createDataSet():
    dataSet = [[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
    labels = ['no surfacing','flippers']
    return dataSet, labels

# Compute the entropy of a given data set
def calEntropy(dataSet):
    numentries = len(dataSet)
    # Use a dict to count how many times each class appears
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    # Compute the entropy
    entropy = 0.0
    for key in labelCounts:
        # Probability of this class
        prob = float(labelCounts[key])/numentries
        entropy -= prob*log(prob,2)
    return entropy

# Split the data set on a given feature
# dataSet: data set to split; axis: index of the splitting feature;
# value: the feature value to keep
# Assume dataSet has n rows and m features
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    # featVec is a 1*m list
    for featVec in dataSet:
        if featVec[axis] == value:
            reduceFeatVec = featVec[:axis]
            reduceFeatVec.extend(featVec[axis+1:])
            # reduceFeatVec is a 1*(m-1) list with the featVec[axis] column removed
            retDataSet.append(reduceFeatVec)
    return retDataSet

# Choose the best way to split the data set
# Information-gain criterion: compute the information gain of every feature
# of the training set, compare them, and choose the feature with the largest gain
# Information gain g(dataSet, Feature) = H(dataSet) - H(dataSet|Feature)
def chooseBestFeatureToSplit(dataSet):
    # Number of features
    numFeatures = len(dataSet[0])-1
    # baseEntropy is H(dataSet)
    baseEntropy = calEntropy(dataSet)
    # bestInfoGain and infoGain are both g; the former tracks the maximum
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        # Unique values of this feature; e.g. if the feature is age,
        # uniqueVals might be: young, middle-aged, old
        uniqueVals = set(featList)
        # Empirical conditional entropy newEntropy, i.e. H(dataSet|Feature)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet,i,value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob*calEntropy(subDataSet)
        # Information gain of this feature
        infoGain = baseEntropy - newEntropy
        # Track the feature with the largest information gain
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

# Decide a leaf node's class by majority vote
# (similar to the voting code in kNN)
def majorityCnt(classList):
    # Build a dict (key: class, value: its count), sort by value in
    # descending order, and return the class (key) with the largest count
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

# Build the decision tree recursively
# Arguments: the data set and a label list (labels of all features in the data set)
def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]
    # The recursion has two stopping conditions:
    # 1. all class labels are identical: return that label
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # 2. all features have been used but the data still cannot be split
    #    into single-class groups: once every feature is consumed,
    #    len(dataSet[0]) == 1, i.e. only the class-label column remains;
    #    return the most frequent class
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # Choose the splitting feature bestFeat (a column index)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    # Store the whole tree in the dict variable myTree
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    # Values in the bestFeat column
    featValues = [example[bestFeat] for example in dataSet]
    # Unique values in the bestFeat column
    uniqueVals = set(featValues)
    # Recursively build the subtrees
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value), subLabels)
    return myTree

# Classification function that uses the decision tree
def classify(inputTree,featLabels,testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    # Convert the label string into an index
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else:
        classLabel = valueOfFeat
    return classLabel

# Store the decision tree with the pickle module
def storeTree(inputTree,filename):
    import pickle
    fw = open(filename,'wb')
    pickle.dump(inputTree,fw)
    fw.close()

# Read back the file written above with the pickle module
def grabTree(filename):
    import pickle
    fr = open(filename,'rb')
    return pickle.load(fr)
Test code:
if __name__ == "__main__":
    myData,labels = createDataSet()
    print(myData)               #[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    print(labels)               #['no surfacing', 'flippers']
    print(calEntropy(myData))   #0.9709505944546686
    """
    # The more classes there are, the larger the entropy
    myData[0][-1] = 'maybe'
    print(myData)               #[[1, 1, 'maybe'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    print(calEntropy(myData))   #1.3709505944546687
    print(splitDataSet(myData,0,1))          #[[1, 'yes'], [1, 'yes'], [0, 'no']]
    print(chooseBestFeatureToSplit(myData))  #0: feature 0 is the best feature to split on
    myTree = createTree(myData,labels)
    print(myTree)               #{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    """
    myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    print(classify(myTree,labels,[1,0]))     #no
    print(classify(myTree,labels,[1,1]))     #yes
    storeTree(myTree,'classifierStorage.txt')
    print(grabTree('classifierStorage.txt'))
    #result: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    # Read the contact lens data
    fr = open('lenses.txt')
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels = ['ages','prescript','astigmatic','tearRate']
    lensesTree = createTree(lenses,lensesLabels)
    # Print the contact-lens decision tree as a dict
    print(lensesTree)
    """lensesTree result:
    {'tearRate': {'normal': {'astigmatic': {'yes': {'prescript': {'myope': 'hard',
    'hyper': {'ages': {'pre': 'no lenses', 'young': 'hard', 'presbyopic': 'no lenses'}}}},
    'no': {'ages': {'pre': 'soft', 'young': 'soft',
    'presbyopic': {'prescript': {'myope': 'no lenses', 'hyper': 'soft'}}}}}},
    'reduced': 'no lenses'}}
    """
3.2 Drawing the Tree with matplotlib Annotations
Code implementation:
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 15 18:41:40 2018
file name: treePlot.py
@author: lizihua
"""
import matplotlib.pyplot as plt
from tree import createTree

# Draw the tree with matplotlib's annotation facility
# Draw tree nodes as text annotations
# Define the box and arrow styles
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

# Draw an annotation with an arrow
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

# Build the annotated tree
# Count the leaf nodes
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # Test whether the child is a dict (i.e. a subtree)
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# Compute the depth of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # Test whether the child is a dict (i.e. a subtree)
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

# Fill in text between parent and child nodes
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    # Compute the width and height
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0+float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    # Label the edge with the child's attribute value
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Decrease the y offset
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5,1.0), '')
    plt.show()

# Return trees stored in advance
def retrieveTree(i):
    listOfTree = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', '1': 'yes'}}, 1: 'no'}}}}]
    return listOfTree[i]
Test code 1:
if __name__ == "__main__":
    print(retrieveTree(1))
    #result: {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', '1': 'yes'}}, 1: 'no'}}}}
    myTree = retrieveTree(0)
    print(getNumLeafs(myTree))   #3
    print(getTreeDepth(myTree))  #2
    createPlot(myTree)
Test result 1: (figure: the tree returned by retrieveTree(0), as drawn by createPlot)
Test code 2 (contact lens data):
if __name__ == "__main__":
    fr = open('lenses.txt')
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels = ['ages','prescript','astigmatic','tearRate']
    lensesTree = createTree(lenses,lensesLabels)
    print(lensesTree)
    createPlot(lensesTree)
Test result 2: (figure: the plotted contact-lens decision tree)
4. Pruning the Decision Tree
A simple method for pruning a decision tree is sketched below.
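One common simple approach is reduced-error post-pruning: walk the tree bottom-up over a held-out data set and collapse a subtree into a single majority-class leaf whenever the leaf misclassifies no more held-out vectors than the subtree does. The sketch below is only an illustration under assumptions, not necessarily the method intended here: it reuses classify and majorityCnt from tree.py, expects the nested-dict tree format produced by createTree, expects held-out rows shaped like the training rows (class label in the last column), assumes every feature value seen in the held-out data has a branch in the tree, and takes the majority class from the held-out subset rather than from training counts.

from tree import classify, majorityCnt

def prune(tree, testData, featLabels):
    """Reduced-error post-pruning sketch: prune the children first,
    then collapse this node into a majority-class leaf if that does
    not increase the error on the held-out data."""
    firstStr = list(tree.keys())[0]
    featIndex = featLabels.index(firstStr)
    secondDict = tree[firstStr]
    # Recurse into dict-valued children, passing each one the matching
    # slice of the held-out data.
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            subset = [vec for vec in testData if vec[featIndex] == key]
            secondDict[key] = prune(secondDict[key], subset, featLabels)
    if not testData:
        # No held-out evidence reaches this node; keep the subtree.
        return tree
    # Errors made by the subtree (class label is the last column).
    subtreeErr = sum(1 for vec in testData
                     if classify(tree, featLabels, vec[:-1]) != vec[-1])
    # Errors made by a single leaf predicting the majority class of the
    # held-out subset (a simplification; training counts could be used).
    majority = majorityCnt([vec[-1] for vec in testData])
    leafErr = sum(1 for vec in testData if vec[-1] != majority)
    return majority if leafErr <= subtreeErr else tree

For example, prune(myTree, heldOut, ['no surfacing', 'flippers']) with heldOut rows like [1, 0, 'no'] returns either the original tree or a partially collapsed one; nodes that the held-out data never reaches are left untouched.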