def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]# firstStr = 'no surfacing’
secondDict = myTree[firstStr]#secondDict = myTree[firstStr] = {0: 'no', 1:{'flippers':{0: 'no', 1: 'yes'}
for key in secondDict.keys():#在除去第一个节点的树的键中遍历
if type(secondDict[key])== dict:#判断secondDict[key]是否是字典
numLeafs += getNumLeafs(secondDict[key])#是,就证明树这一支还没有结束
else:
numLeafs += 1#如果不是,叶节点加1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():]#secondDict = myTree[firstStr] = {0: 'no', 1:{'flippers':{0: 'no', 1: 'yes'}
if type(secondDict[key])== dict:
thisDepth = 1 + getTreeDepth(secondDict[key])#thisDepth没有赋值,所以不能用+=,根 所以起始为1
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
keys()用来遍历字典中的所有键
例如: for name in favorite_languages.keys(): print(name.title())
以树{‘no surfacing’: {0: ‘no’, 1:{‘flippers’:{0: ‘no’, 1: ‘yes’}}}}为 list[‘no surfacing’]
list(myTree.keys())[0] 应该是 ‘no surfacing’ firstStr = ‘no surfacing’
secondDict = myTree[firstStr] = {0: ‘no’, 1:{‘flippers’:{0: ‘no’, 1: ‘yes’}
list() 创建一个list列表
>>> import treePlotter
>>> from imp import reload
>>> reload(treePlotter)
<module 'treePlotter' from 'E:\\Python\\treePlotter.py'>
>>> myTree = treePlotter.retrieveTree(0)
>>> treePlotter.getNumLeafs(myTree)
3
>>> treePlotter.getTreeDepth(myTree)
2
#在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree,parentPt,nodeTxt):
numLeafs = getNumLeafs(myTree)#当前树的叶子数
depth = getTreeDepth(myTree)#没有用到这个变量
firstStr = list(myTree.keys())[0]
#cntrPt文本中心点 parentPt 指向文本中心的点
cntrPt = (plotTree.xoff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,plotTree.yoff)
plotMidText(cntrPt, parentPt, nodeTxt)#画分支上的键
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yoff = plotTree.yoff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key])== dict:
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xoff = plotTree.xoff + 1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xoff,plotTree.yoff), cntrPt, leafNode)
plotMidText((plotTree.xoff,plotTree.yoff), cntrPt, str(key))
plotTree.yoff = plotTree.yoff + 1.0/plotTree.totalD
对之前的createPlot()函数做了改动
def createPlot(inTree):
#创建一个新图形
fig = plt.figure(1, facecolor='white')
#清空绘图区
fig.clf()
axprops = dict(xticks=[], yticks=[])#定义横纵坐标轴
#给全局变量createPlot.ax1赋值,绘制图像,无边框,无坐标轴
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))#全局变量宽度=叶子数
plotTree.totalD = float(getTreeDepth(inTree))#全局变量高度=深度
#但这样会使整个图形偏右因此初始的,将x值向左移一点。
plotTree.xoff = -0.5/plotTree.totalW;#图形的大小是0-1,0-1,例如绘制3个叶子结点,坐标应为1/3,2/3,3/3
plotTree.yoff = 1.0;
plotTree(inTree, (0.5,0.1), '')
plt.show() #显示最终绘制结果
不知道为啥最后会成了这样子,多了一条线,没找到问题在哪
>>> import treePlotter
>>> from imp import reload
>>> reload(treePlotter)
<module 'treePlotter' from 'E:\\Python\\treePlotter.py'>
>>> myTree = treePlotter.retrieveTree(0)
>>> treePlotter.createPlot(myTree)
>>> treePlotter.createPlot(myTree)
>>> myTree['no surfacing'][3]='maybe'
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}}
>>> treePlotter.createPlot(myTree)