之前以为treePlotter是一个待安装的库,后来总是安装不成功,在学习过程中发现它其实就是一系列函数组成的自定义模块,下面介绍该模块的代码以及怎么使用该模块。
第一步:新建一个python包,在__init__文件中键入以下代码:
# _*_ coding: UTF-8 _*_ import matplotlib.pyplot as plt """绘决策树的函数""" decisionNode = dict(boxstyle="sawtooth", fc="0.8") # 定义分支点的样式 leafNode = dict(boxstyle="round4", fc="0.8") # 定义叶节点的样式 arrow_args = dict(arrowstyle="<-") # 定义箭头标识样式 # 计算树的叶子节点数量 def getNumLeafs(myTree): numLeafs = 0 firstStr = list(myTree.keys())[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': numLeafs += getNumLeafs(secondDict[key]) else: numLeafs += 1 return numLeafs # 计算树的最大深度 def getTreeDepth(myTree): maxDepth = 0 firstStr = list(myTree.keys())[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': thisDepth = 1 + getTreeDepth(secondDict[key]) else: thisDepth = 1 if thisDepth > maxDepth: maxDepth = thisDepth return maxDepth # 画出节点 def plotNode(nodeTxt, centerPt, parentPt, nodeType): createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', \ xytext=centerPt, textcoords='axes fraction', va="center", ha="center", \ bbox=nodeType, arrowprops=arrow_args) # 标箭头上的文字 def plotMidText(cntrPt, parentPt, txtString): lens = len(txtString) xMid = (parentPt[0] + cntrPt[0]) / 2.0 - lens * 0.002 yMid = (parentPt[1] + cntrPt[1]) / 2.0 createPlot.ax1.text(xMid, yMid, txtString) def plotTree(myTree, parentPt, nodeTxt): numLeafs = getNumLeafs(myTree) depth = getTreeDepth(myTree) firstStr = list(myTree.keys())[0] cntrPt = (plotTree.x0ff + \ (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.y0ff) plotMidText(cntrPt, parentPt, nodeTxt) plotNode(firstStr, cntrPt, parentPt, decisionNode) secondDict = myTree[firstStr] plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': plotTree(secondDict[key], cntrPt, str(key)) else: plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW plotNode(secondDict[key], \ (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode) plotMidText((plotTree.x0ff, plotTree.y0ff) \ , cntrPt, str(key)) plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD def createPlot(inTree): fig = plt.figure(1, facecolor='white') fig.clf() axprops = dict(xticks=[], yticks=[]) createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) plotTree.totalW = float(getNumLeafs(inTree)) plotTree.totalD = float(getTreeDepth(inTree)) plotTree.x0ff = -0.5 / plotTree.totalW plotTree.y0ff = 1.0 plotTree(inTree, (0.5, 1.0), '') plt.show() if __name__=='__main__': createPlot()
参考廖雪峰教程,最后两行代码的意思是:当我们在命令行运行模块文件时,Python解释器把一个特殊变量
__name__
置为
__main__
,而如果在其他地方导入该模块时,
if
判断将失败,因此,这种
if
测试可以让一个模块通过命令行运行时执行一些额外的代码,最常见的就是运行测试。
第二步:关于如何导入treePlotter
这里我是直接将模块所在的文件夹放在运行程序文件夹下,这样只需要在待运行程序中直接import即可。
注意事项:
1、要让某个文件成为模块的话,在其目录下必须有一个__init__.py的文件
2、创建自己的模块时,要注意:
(1)模块名要遵循Python变量命名规范,不要使用中文、特殊字符;
(2)模块名不要和系统模块名冲突,最好先查看系统是否已存在该模块,检查方法是在Python交互环境执行import abc,若成功则说明系统存在此模块。
第三步:演示结果