决策树图形展示

使用 Python 的 Matplotlib 绘图工具绘制决策树图形，以下代码亲测可用。

目标是从原始数据集中创建决策树，并使用 Python 的 Matplotlib 库绘制树形图；下面的示例直接使用 retriveTree 中预先构造好的字典树。

#coding=UTF-8
import matplotlib.pyplot as plt
# Box style for internal (decision) nodes: sawtooth outline, grey fill "0.8".
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
# Box style for leaf nodes: rounded outline, same grey fill.
leafNode = {"boxstyle": "round4", "fc": "0.8"}
# Arrow style passed as ``arrowprops`` to ``annotate`` when drawing edges.
arrow_args = {"arrowstyle": "<-"}

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw a text box at ``centerPt`` with an arrow linking it to ``parentPt``.

    Both points are given in axes-fraction coordinates (0..1) of the axes
    stored on ``createPlot.ax1``. The arrow style comes from the module-level
    ``arrow_args``.

    Args:
        nodeTxt: text shown inside the box.
        centerPt: (x, y) position of the box.
        parentPt: (x, y) position the arrow connects to.
        nodeType: bbox style dict (e.g. ``decisionNode`` or ``leafNode``).
    """
    annotate_kwargs = dict(
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
    createPlot.ax1.annotate(nodeTxt, **annotate_kwargs)

def createPlot():
    """Demo: draw one decision-node box and one leaf-node box, then show them.

    NOTE: a second ``def createPlot(inTree)`` later in this file rebinds the
    name ``createPlot``, so this zero-argument demo is only callable before
    that later definition executes.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    # plotNode draws on whatever axes is stored on createPlot.ax1.
    createPlot.ax1 = plt.subplot(111, frameon=False)
    demo_nodes = (
        ('decisive node', (0.5, 0.1), (0.1, 0.5), decisionNode),
        ('leaf node ', (0.8, 0.1), (0.3, 0.8), leafNode),
    )
    for label, center, parent, box_style in demo_nodes:
        plotNode(label, center, parent, box_style)
    plt.show()
	
def getNumLeafs(myTree):
    """Count the leaves of a nested-dict decision tree.

    The tree has the shape ``{feature: {branch_value: subtree_or_label}}``.
    A branch whose value is a dict is another subtree; any other value is a
    leaf (a class label). Only the first key of ``myTree`` is read.

    Args:
        myTree: the (sub)tree to measure.

    Returns:
        The number of leaf values in the whole tree; 0 for an empty dict or a
        root with no branches.
    """
    if not myTree:
        return 0
    # dict.keys() returns a non-subscriptable view on Python 3, so
    # ``myTree.keys()[0]`` raises TypeError there; iter() works on 2 and 3.
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    numLeafs = 0
    for child in secondDict.values():
        # Same plain-dict test used by getTreeDepth/plotTree, so leaf counts
        # and layout agree on what counts as a subtree.
        if type(child).__name__ == 'dict':
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs
	
def getTreeDepth(myTree):
    """Return the depth of a nested-dict decision tree in decision levels.

    A node whose branches all end in leaves has depth 1; every nested subtree
    adds one level. Only the first key of ``myTree`` is read.

    Args:
        myTree: tree of the shape ``{feature: {branch_value: subtree_or_label}}``.

    Returns:
        The maximum number of decision nodes on any root-to-leaf path, or 0
        for an empty dict or a root with no branches.
    """
    if not myTree:
        return 0
    # dict.keys() is not subscriptable on Python 3; iter() works on 2 and 3.
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    maxDepth = 0
    for child in secondDict.values():
        # Same plain-dict test used by getNumLeafs/plotTree, so depth and
        # leaf counts agree on what counts as a subtree.
        if type(child).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
	
def retriveTree(i):
    """Return sample decision tree number ``i`` (a freshly built dict).

    Tree 0 splits on ``no surfacing`` and then ``flippers``. Tree 1 adds a
    ``head`` split under ``flippers == 0``. ``i`` is used to index a
    two-element list, so negative indices work and out-of-range values raise
    IndexError. A new object is built on every call, so callers may mutate
    the result freely.
    """
    shallow = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    deeper = {
        'no surfacing': {
            0: 'no',
            1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}},
        }
    }
    samples = [shallow, deeper]
    return samples[i]
	
def plotMidText(cntrPt, parentPt, txtString):
    """Write ``txtString`` halfway along the segment from ``cntrPt`` to ``parentPt``.

    Used to label each edge with its branch value. The text is drawn on the
    axes stored on ``createPlot.ax1``.
    """
    px, py = parentPt[0], parentPt[1]
    cx, cy = cntrPt[0], cntrPt[1]
    createPlot.ax1.text((px - cx) / 2.0 + cx, (py - cy) / 2.0 + cy, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw ``myTree`` beneath the node at ``parentPt``.

    Layout state is kept in function attributes that ``createPlot`` sets up
    before the first call:

    - ``plotTree.totalW``: leaf count of the whole tree.
    - ``plotTree.totalD``: depth of the whole tree.
    - ``plotTree.xOff`` and ``plotTree.yOff``: the current drawing cursor,
      which this function advances.

    Args:
        myTree: (sub)tree of the shape ``{feature: {branch_value: subtree_or_label}}``.
        parentPt: (x, y) axes-fraction position of the parent node.
        nodeTxt: label for the edge coming from the parent ('' for the root).
    """
    numLeafs = getNumLeafs(myTree)
    # dict.keys() is not subscriptable on Python 3; iter() works on 2 and 3.
    firstStr = next(iter(myTree))
    # Centre this decision node over the horizontal span of its own leaves.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Move the cursor down one level while drawing the children...
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        # Same plain-dict test as getNumLeafs/getTreeDepth, so the layout
        # matches the precomputed totals.
        if type(child).__name__ == 'dict':
            plotTree(child, cntrPt, str(key))
        else:
            # Each leaf occupies the next horizontal slot of width 1/totalW.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # ...and restore it before returning to the parent level.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    """Render the decision tree ``inTree`` into figure 1 and display it.

    Initializes the layout attributes that ``plotTree`` reads (``totalW``,
    ``totalD``, ``xOff``, ``yOff``) and stores the drawing axes on
    ``createPlot.ax1`` for ``plotNode`` and ``plotMidText``.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    # Frameless axes with no ticks; nodes are placed in axes-fraction coords.
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a slot left of 0. plotTree adds one slot before each leaf,
    # so the first leaf is drawn centred in the first slot.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

可以在字典树中随意添加一些数据（例如修改 retriveTree 返回的树），再调用 createPlot(retriveTree(0)) 观察图形的变化。






发布了192 篇原创文章 · 获赞 27 · 访问量 10万+

猜你喜欢

转载自blog.csdn.net/lovely_girl1126/article/details/79169322