一、trees.py
"""ID3 decision-tree helpers: entropy, splitting, tree building, classification.

A data set is a list of rows; every row is a list whose last element is the
class label and whose other elements are discrete feature values.  A tree is
a nested dict of the form ``{feature_label: {feature_value: subtree_or_class}}``.
"""
from math import log
import operator
import pickle


def calcShannonEnt(dataSet):
    """Return the Shannon entropy (base 2) of the class labels in ``dataSet``.

    The class label is taken from the last element of every row.  An empty
    data set yields ``0.0``.
    """
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def createDataSet():
    """Return a small sample data set and the labels of its two features."""
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def splitDataSet(dataSet, axis, value):
    """Return the rows whose feature ``axis`` equals ``value``, minus that column.

    The input rows are not modified; each returned row is a new list.
    """
    return [
        featVec[:axis] + featVec[axis + 1:]
        for featVec in dataSet
        if featVec[axis] == value
    ]


def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the largest information gain.

    Information gain is the entropy of ``dataSet`` minus the weighted entropy
    of the subsets obtained by splitting on the feature.  Returns ``-1`` when
    no feature yields a strictly positive gain.
    """
    numFeatures = len(dataSet[0]) - 1  # the last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        uniqueVals = set(example[i] for example in dataSet)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def createTree(dataSet, labels):
    """Recursively build a decision tree from ``dataSet``.

    Args:
        dataSet: non-empty list of rows (feature values followed by the class).
        labels: feature names, one per feature column.  The list is NOT
            modified, so it can be reused afterwards (e.g. with ``classify``).

    Returns:
        A class label when the node is a leaf, otherwise a nested dict
        ``{feature_label: {feature_value: subtree_or_class}}``.
    """
    classList = [example[-1] for example in dataSet]
    # Stop splitting when every row belongs to the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # All features have been consumed: fall back to the most common class.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # -1 means no feature reduces the entropy.  Using it as an index would
    # silently split on the class column, so vote instead.
    if bestFeat == -1:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    # Build the reduced label list as a copy so the caller's list is untouched.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    myTree = {bestFeatLabel: {}}
    for value in set(example[bestFeat] for example in dataSet):
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels
        )
    return myTree


def majorityCnt(classList):
    """Return the class name that occurs most often in ``classList``.

    Ties are resolved in favour of the class that appears first.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # Sort (class, count) pairs by count, descending; the sort is stable.
    sortedClassCount = sorted(
        classCount.items(), key=operator.itemgetter(1), reverse=True
    )
    return sortedClassCount[0][0]


def classify(inputTree, featLabels, testVec):
    """Return the class that ``inputTree`` predicts for ``testVec``.

    Args:
        inputTree: nested dict tree as produced by ``createTree``.
        featLabels: feature names; position ``i`` names ``testVec[i]``.
        testVec: feature values of the sample to classify.

    Raises:
        KeyError: if the tree has no branch for one of the sample's values.
        ValueError: if a feature used by the tree is missing from ``featLabels``.
    """
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    featValue = testVec[featIndex]
    if featValue not in secondDict:
        raise KeyError(
            "no branch for value {!r} of feature {!r}".format(featValue, firstStr)
        )
    subtree = secondDict[featValue]
    if isinstance(subtree, dict):
        return classify(subtree, featLabels, testVec)
    return subtree


def storeTree(inputTree, filename):
    """Pickle ``inputTree`` into the file ``filename``."""
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)


def grabTree(filename):
    """Load and return a tree previously written by ``storeTree``.

    NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    with open(filename, 'rb') as fr:
        return pickle.load(fr)


def main():
    """Run the demo: sample data set, classification, pickling, lenses tree."""
    # Imported here so that importing this module does not require the
    # plotting helpers.
    import treePlotter

    myDat, labels = createDataSet()
    print("myDat", myDat)
    print("labels", labels)
    print("shannonEnt", calcShannonEnt(myDat))
    print("bestFeature", chooseBestFeatureToSplit(myDat))
    myTree = treePlotter.retrieveTree(0)
    print("myTree", myTree)
    print("classify[1,0]", classify(myTree, labels, [1, 0]))
    print("classify[1,1]", classify(myTree, labels, [1, 1]))
    storeTree(myTree, 'classifierStorage.txt')
    print("grabTree", grabTree('classifierStorage.txt'))

    with open('lenses.txt') as fr:
        lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    print("lenses", lenses)
    print("lensesLabels", lensesLabels)
    lensesTree = createTree(lenses, lensesLabels)
    print("lensesTree", lensesTree)
    treePlotter.createPlot(lensesTree)


if __name__ == "__main__":
    main()
-----------------------------------------------------------------------------------
二、treePlotter.py
'''
Created on Oct 14, 2010
@author: Peter Harrington
'''
import matplotlib.pyplot as plt
# Box and arrow styles handed to matplotlib when drawing tree nodes.
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}  # box for internal (decision) nodes
leafNode = {"boxstyle": "round4", "fc": "0.8"}  # box for leaf nodes
arrow_args = {"arrowstyle": "<-"}  # arrow linking a node to its parent
def getNumLeafs(myTree):
    """Return the number of leaf nodes in the nested-dict tree ``myTree``.

    ``myTree`` has the form ``{feature_label: {feature_value: child}}``.  A
    child that is a dict (or a dict subclass) is a subtree; any other value
    counts as one leaf.
    """
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    for child in myTree[firstStr].values():
        if isinstance(child, dict):
            # Subtree: add however many leaves hang below it.
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    """Return the number of decision levels on the longest path in ``myTree``.

    ``myTree`` has the form ``{feature_label: {feature_value: child}}``.  A
    child that is a dict (or a dict subclass) is a subtree and adds its own
    depth; any other value is a leaf and contributes a depth of 1.
    """
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    for child in myTree[firstStr].values():
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Annotate the shared axes with a boxed node linked to its parent.

    Args:
        nodeTxt: text shown inside the box.
        centerPt: (x, y) position of the text, in axes-fraction coordinates.
        parentPt: (x, y) point the annotation refers to, in axes-fraction
            coordinates.
        nodeType: box style dict passed as ``bbox`` (e.g. ``decisionNode``).
    """
    annotationStyle = dict(
        xycoords='axes fraction',
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
    axes = createPlot.ax1  # axes object stored by createPlot
    axes.annotate(nodeTxt, xy=parentPt, xytext=centerPt, **annotationStyle)
def plotMidText(cntrPt, parentPt, txtString):
    """Write ``txtString`` halfway between ``cntrPt`` and ``parentPt``.

    Both points are (x, y) pairs; the text is drawn on the axes stored in
    ``createPlot.ax1``, centred and rotated by 30 degrees.
    """
    childX, childY = cntrPt[0], cntrPt[1]
    xMid = (parentPt[0] - childX) / 2.0 + childX
    yMid = (parentPt[1] - childY) / 2.0 + childY
    textStyle = dict(va="center", ha="center", rotation=30)
    createPlot.ax1.text(xMid, yMid, txtString, **textStyle)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw ``myTree`` below ``parentPt``.

    The layout state lives in function attributes that ``createPlot`` sets
    before the first call: ``plotTree.totalW`` (total leaf count),
    ``plotTree.totalD`` (total depth) and the running cursor
    ``plotTree.xOff`` / ``plotTree.yOff``.

    Args:
        myTree: nested dict ``{feature_label: {feature_value: child}}``; the
            first key is the feature this node splits on.
        parentPt: (x, y) position of the parent node.
        nodeTxt: text written on the edge from the parent to this node.
    """
    numLeafs = getNumLeafs(myTree)  # determines the x width of this subtree
    firstStr = list(myTree.keys())[0]  # the text label for this node
    # Centre this node horizontally over the span its leaves will occupy.
    cntrPt = (
        plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
        plotTree.yOff,
    )
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Move one level down before drawing the children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        child = secondDict[key]
        if isinstance(child, dict):
            # A dict child is itself a tree: recurse with this node as parent.
            plotTree(child, cntrPt, str(key))
        else:
            # Leaf node: advance the x cursor by one leaf slot and draw it.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore the y cursor so siblings of this node stay on the right level.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
    """Draw the decision tree ``inTree`` in matplotlib figure 1 and show it.

    Stores the axes in ``createPlot.ax1`` and initialises the layout
    attributes on ``plotTree`` before starting the recursive drawing.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    # Empty tick lists: no ticks on either axis.  Drop the two tick arguments
    # to keep the ticks for demo purposes.
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf slot left of the axes, at the top.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    rootParent = (0.5, 1.0)
    plotTree(inTree, rootParent, '')
    plt.show()
# def createPlot():
# fig = plt.figure(1, facecolor='white')
# fig.clf()
# createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
# plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
# plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
# plt.show()
def retrieveTree(i):
    """Return the ``i``-th of two hard-coded sample decision trees.

    Index 0 is a two-level tree and index 1 a three-level tree; a fresh
    object is built on every call.
    """
    twoLevelTree = {
        'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}},
    }
    threeLevelTree = {
        'no surfacing': {
            0: 'no',
            1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}},
        },
    }
    samples = [twoLevelTree, threeLevelTree]
    return samples[i]
#createPlot()
# myTree = retrieveTree(0)
# print("num of leafs", getNumLeafs(myTree))
# print("depth of tree", getTreeDepth(myTree))
# myTree['no surfacing'][3] = 'maybe'
# createPlot(myTree)