上一篇文章学习了kNN算法,它能完成许多分类任务,但是它最大的缺点就是无法给出数据的内部含义,而相比之下决策树的优势就是
- 数据形式非常容易理解
- 可以持久化分类器,而kNN则必须每次分类时重新学习一遍
决策树同样是一类常见的机器学习算法,使用它的目的是希望能从给定训练数据集学得一个模型用以对新示例进行分类。
决策树算法主流上有ID3、CART、C4.5等。
在讨论各个算法对样本集合的划分准则前,先带来一些定义:
信息熵
- 信息熵(information entropy)是度量样本集合纯度最常用的一种指标,假定当前样本集合D中第k类样本所占的比例为pk,则D的信息熵定义为:
Ent(D)=-∑p(xk)log(2,p(xk)) (k=1,2,..n)- 计算信息熵时约定if p=0 => p*log(2,p)=0
信息增益
- 在划分数据集之前之后信息发生的变化成为信息增益
- 假定离散属性a有V个可能的值{a1,a2,…,av},若使用a来对样本集D进行划分,则会产生V个分支结点,其中第v个分支结点包含了D中所有在属性a上取值为av的样本,记为Dv,给不同分支结点赋予权重|Dv|/|D|,于是可计算属性a对样本集D进行划分所获得的“信息增益”
Gain(D,a)=Ent(D)-∑|Dv|/|D|*Ent(Dv)- 一般而言信息增益越大,意味着使用属性a来进行划分所得的纯度提升越大
- 对可取数值数目较多的属性有所偏好
增益率
- Gain_ratio(D,a)=Gain(D,a)/IV(a)
- 其中IV(a)=-∑(|Dv|/|D|*log2(|Dv|/|D|))称为a的“固有值”
- 对可取数值数目较少的属性有所偏好
基尼
- 数据集D的纯度可用基尼值来衡量:
Gini(D)=1-∑pk^2 (k=1,2,…,)
直观的说,Gini(D)反映了从D中随机抽取两个样本,其类别标记不一致的概率,因此,Gini(D)越小,数据集D的纯度越高- 属性a的基尼指数定义为
Gini_index(D,a)=∑(|Dv|/|D|*Gini(Dv))
一般选取划分后基尼指数最小的属性作为最优划分属性
于是刚刚提到的三种决策树算法选取的划分准则分别是:
ID3–信息增益,C4.5–增益率,CART–基尼指数
其实决策树一般还会涉及到剪枝处理,连续值处理,缺失值处理,多变量决策树等问题,这些问题将放在决策树(二)(三)中带来学习分析,今天的实战的重点是ID3,有关源码以及测试样本在github上
github链接地址
计算信息熵
def calcShannonEnt(dataSet):
numEntries=len(dataSet)
labelCounts={}
for featVec in dataSet:
currentLabel=featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel]=0
labelCounts[currentLabel]+=1
shannonEnt=0.0
for key in labelCounts:
prob=float(labelCounts[key])/numEntries
shannonEnt-=prob*log(prob,2)
return shannonEnt
说明:dataset为类似如下特征的数据
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
选取某特征值来划分数据集
#axis为所选取的用来划分的特征值的下标,value为划分所得数据集中该属性的取值
def splitDataSet(dataSet,axis,value):
retDataSet=[]
for featVec in dataSet:
if featVec[axis]==value:
reducedFeatVec=featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
选择最好的划分属性
def chooseBestFeatureToSplit(dataSet):
#the number of feature attributes that the current data packet contains
numFeatures=len(dataSet[0])-1
# entropy 熵
baseEntropy=calcShannonEnt(dataSet)
bestInfoGain=0.0
bestFeature=-1
for i in range(numFeatures):
featList = [example[i] for example in dataSet]
uniqueVals=set(featList)
newEntropy=0.0
for value in uniqueVals:
subDataset=splitDataSet(dataSet,i,value)
prob=len(subDataset)/float(len(dataSet))
newEntropy+=prob*calcShannonEnt(subDataset)
infoGain=baseEntropy-newEntropy
if(infoGain>bestInfoGain):
bestInfoGain=infoGain
bestFeature=i
return bestFeature
给定一个list,返回其中出现次数最多的元素
def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys():
classCount[vote]=0
classCount[vote]+=1
sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
递归构造决策树(以字典形式表示)
def createTree(dataSet,labels):
classList=[example[-1] for example in dataSet]
#类别完全相同则停止继续划分
if classList.count(classList[0]) == len(classList):
return classList[0]
#遍历完所有特征是返回出现次数最多的类别
if len(dataSet[0])==1:
return majorityCnt(classList)
bestFeat=chooseBestFeatureToSplit(dataSet)
bestFeatLabel=labels[bestFeat]
mytree={bestFeatLabel:{}}
del(labels[bestFeat])
featValues=[example[bestFeat] for example in dataSet]
uniqueVals=set(featValues)
for value in uniqueVals:
subLabels=labels[:]
mytree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
return mytree
如果觉得单纯的字典形式不够直观,还可以通过matplotlib将其画出
# coding=utf-8
from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt
from math import log
import operator
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
# def createPlot():
# fig = plt.figure(1, facecolor='white')
# fig.clf()
# createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
# plotNode('决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode)
# plotNode('叶节点', (0.8, 0.1), (0.3, 0.8), leafNode)
# plt.show()
def createPlot(inTree):
fig=plt.figure(1,facecolor='white')
fig.clf()
axprops=dict(xticks=[],yticks=[])
createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
plotTree.totalW=float(getNumLeafs(inTree))
plotTree.totalD=float(getTreeDepth(inTree))
plotTree.xOff=-0.5/plotTree.totalW
plotTree.yOff=1.0
plotTree(inTree,(0.5,1.0),'')
plt.show()
def getNumLeafs(myTree):
numleafs=0
firstStr=next(iter(myTree))
sencondDict=myTree[firstStr]
for key in sencondDict.keys():
if type(sencondDict[key]).__name__=='dict':
numleafs+=getNumLeafs(sencondDict[key])
else:
numleafs+=1
return numleafs
def getTreeDepth(myTree):
maxDepth=0
firstStr=next(iter(myTree))
secondDict=myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
thisDepth=1+getTreeDepth(secondDict[key])
else:
thisDepth=1
if thisDepth>maxDepth:
maxDepth=thisDepth
return maxDepth
def plotMidText(cntrPt,parentPt,txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid,yMid,txtString)
def plotTree(myTree,parentPt,nodeTxt):
numLeafs=getNumLeafs(myTree)
depth=getTreeDepth(myTree)
firstStr=next(iter(myTree))
cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
plotMidText(cntrPt,parentPt,nodeTxt)
plotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDict=myTree[firstStr]
plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
如何通过决策树分类
#自顶向下递归遍历
def classify(inputTree,featlabels,testVec):
firstStr=next(iter(inputTree))
secondDict=inputTree[firstStr]
featIndex=featlabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex]==key:
if type(secondDict[key]).__name__=='dict':
classLabel=classify(secondDict[key],featlabels,testVec)
else:
classLabel=secondDict[key]
return classLabel
说明
featlabels格式形如:['no surfacing','flippers']
testVec格式形如: [0,1]
根据txt中的数据集的实战分类
def lensesClassify():
fr=open('lenses.txt','r')
lenses=[inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree=createTree(lenses,lensesLabels)
tp.createPlot(lensesTree)