1. First, use the annotate() function from the matplotlib.pyplot module. Its annotation feature is used to draw the nodes of the tree:
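import matplotlib.pyplot as plt

# Use dict() to build property dictionaries for the nodes and the arrow;
# they are passed to annotate() later as the values of bbox and arrowprops
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # createPlot.ax1 is a function attribute used as a global variable;
    # the axes it refers to are created later, in createPlot()
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)

# Parameters of plotNode():
#   nodeTxt  : the annotation text, a string
#   centerPt : position of the annotation text (the node being drawn)
#   parentPt : position of the annotated point (the parent node)
#   nodeType : box style of the node
# Parameters of annotate() used here:
#   xy and xytext         : position of the annotated point and of the annotation text
#   xycoords and textcoords : coordinate systems used for xy and xytext
#   va and ha             : vertical and horizontal alignment, respectively
#   bbox                  : property dict for the box drawn around the text
#   arrowprops            : property dict for the arrow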
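To try the node-drawing call on its own before createPlot() exists, the same annotate() call can be run against an explicit Axes. This is a minimal sketch assuming the headless Agg backend; plot_node is a hypothetical variant that takes the Axes as an argument instead of reading createPlot.ax1. Because arrow_args contains an arrowstyle key, annotate() builds a FancyArrowPatch for the arrow.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def plot_node(ax, nodeTxt, centerPt, parentPt, nodeType):
    """Same annotate() call as plotNode(), with the Axes passed in explicitly (hypothetical helper)."""
    return ax.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                       xytext=centerPt, textcoords='axes fraction',
                       va="center", ha="center", bbox=nodeType,
                       arrowprops=arrow_args)


fig, ax = plt.subplots()
dec = plot_node(ax, "a decision node", (0.5, 0.1), (0.1, 0.5), decisionNode)
leaf = plot_node(ax, "a leaf node", (0.8, 0.1), (0.3, 0.8), leafNode)
fig.canvas.draw()  # render once, so an invalid style name would raise here

# arrowprops has an 'arrowstyle' key, so the arrow is a FancyArrowPatch
print(type(dec.arrow_patch).__name__)  # FancyArrowPatch
```

The text sits at xytext (the node) and the "<-" style puts the arrow head at that end, so each arrow points from the parent position to the node.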
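The detailed usage of annotate() is given by its help text (this is the output of an older matplotlib release; newer releases name the first parameter text instead of s):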
>>> help(plt.annotate)
Help on function annotate in module matplotlib.pyplot:

annotate(*args, **kwargs)
call signature::
annotate(s, xy, xytext=None, xycoords='data',
textcoords='data', arrowprops=None, **kwargs)
Keyword arguments:
Annotate the *x*, *y* point *xy* with text *s* at *x*, *y*
location *xytext*. (If *xytext* = *None*, defaults to *xy*,
and if *textcoords* = *None*, defaults to *xycoords*).
# i.e. the text s is drawn at xytext and annotates the point xy; xytext falls back to xy, and textcoords falls back to xycoords
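*arrowprops*, if not *None*, is a dictionary of line properties (see :class:`matplotlib.lines.Line2D`) for the arrow that connects
annotation to the point.
# i.e. unless arrowprops is None, it is a dict of line properties for the arrow that links the annotated point and the annotation text.
If the dictionary has a key *arrowstyle*, a FancyArrowPatch
instance is created with the given dictionary and is
drawn. Otherwise, a YAArrow patch instance is created and
drawn.
# If the dict has an arrowstyle key, a FancyArrowPatch is built from it and drawn; otherwise a YAArrow is built instead.
# Both are arrow patch classes in matplotlib.patches: FancyArrowPatch is the configurable arrow driven by
# arrowstyle/connectionstyle (arrow_args above uses it), while YAArrow is a simpler arrow shaped by the
# width/frac/headwidth/shrink keys listed below.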
Valid keys for YAArrow are
========= =========================================================
Key       Description
========= =========================================================
width     the width of the arrow in points
frac      the fraction of the arrow length occupied by the head
headwidth the width of the base of the arrow head in points
shrink    oftentimes it is convenient to have the arrowtip
          and base a bit away from the text and point being
          annotated. If *d* is the distance between the text and
          annotated point, shrink will shorten the arrow so the tip
          and base are shrink percent of the distance *d* away from the
          endpoints. ie, ``shrink=0.05 is 5%``
?         any key for :class:`matplotlib.patches.Polygon`
========= =========================================================
Valid keys for FancyArrowPatch are
=============== ===================================================
Key             Description
=============== ===================================================
arrowstyle      the arrow style
connectionstyle the connection style
relpos          default is (0.5, 0.5)
patchA          default is bounding box of the text
patchB          default is None
shrinkA         default is 2 points
shrinkB         default is 2 points
mutation_scale  default is text size (in points)
mutation_aspect default is 1.
?               any key for :class:`matplotlib.patches.PathPatch`
=============== ===================================================
*xycoords* and *textcoords* are strings that indicate the
coordinates of *xy* and *xytext*.
================= ===================================================
Property          Description
================= ===================================================
'figure points'   points from the lower left corner of the figure
'figure pixels'   pixels from the lower left corner of the figure
'figure fraction' 0,0 is lower left of figure and 1,1 is upper right
'axes points'     points from lower left corner of axes
'axes pixels'     pixels from lower left corner of axes
'axes fraction'   0,0 is lower left of axes and 1,1 is upper right
'data'            use the coordinate system of the object being
                  annotated (default)
'offset points'   Specify an offset (in points) from the *xy* value
'polar'           you can specify *theta*, *r* for the annotation,
                  even in cartesian plots. Note that if you
                  are using a polar axes, you do not need
                  to specify polar for the coordinate
                  system since that is the native "data" coordinate
                  system.
================= ===================================================
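The 'axes fraction' row can be checked directly: it uses the same mapping as the Axes' transAxes transform, so transforming (0, 0) and (1, 1) should give the lower-left and upper-right corners of the axes rectangle in display pixels. A minimal sketch, assuming the headless Agg backend.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend, no display needed
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()

# ax.transAxes maps 'axes fraction' coordinates to display (pixel) coordinates
lower_left = ax.transAxes.transform((0.0, 0.0))
upper_right = ax.transAxes.transform((1.0, 1.0))

# ax.bbox is the axes rectangle in display coordinates
print(lower_left, ax.bbox.p0)   # same point: lower-left corner of the axes
print(upper_right, ax.bbox.p1)  # same point: upper-right corner of the axes
```

This is why plotNode() can place every node with numbers between 0 and 1, regardless of the figure size.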
If a 'points' or 'pixels' option is specified, values will be
added to the bottom-left and if negative, values will be
subtracted from the top-right. Eg::
# In other words: a positive value is an offset measured from the left/bottom edge,
# while a negative value is measured inward from the right/top edge.
# 10 points to the right of the left border of the axes and
# 5 points below the top border
xy=(10,-5), xycoords='axes points'
You may use an instance of
:class:`~matplotlib.transforms.Transform` or
:class:`~matplotlib.artist.Artist`. See
:ref:`plotting-guide-annotation` for more details.
The *annotation_clip* attribute controls the visibility of the
annotation when it goes outside the axes area. If True, the
annotation will only be drawn when the *xy* is inside the
axes. If False, the annotation will always be drawn regardless
of its position. The default is *None*, which behaves as True
only if *xycoords* is "data".
Additional kwargs are Text properties:
agg_filter: unknown
alpha: float (0.0 transparent through 1.0 opaque)
animated: [True | False]
axes: an :class:`~matplotlib.axes.Axes` instance
backgroundcolor: any matplotlib color
bbox: rectangle prop dict
clip_box: a :class:`matplotlib.transforms.Bbox` instance
clip_on: [True | False]
clip_path: [ (:class:`~matplotlib.path.Path`, :class:`~matplotlib.transforms.Transform`) | :class:`~matplotlib.patches.Patch` | None ]
color: any matplotlib color
contains: a callable function
family or fontfamily or fontname or name: [ FONTNAME | 'serif' | 'sans-serif' | 'cursive' | 'fantasy' | 'monospace' ]
figure: a :class:`matplotlib.figure.Figure` instance
fontproperties or font_properties: a :class:`matplotlib.font_manager.FontProperties` instance
gid: an id string
horizontalalignment or ha: [ 'center' | 'right' | 'left' ]
label: any string
linespacing: float (multiple of font size)
lod: [True | False]
multialignment: ['left' | 'right' | 'center' ]
path_effects: unknown
picker: [None|float|boolean|callable]
position: (x,y)
rasterized: [True | False | None]
rotation: [ angle in degrees | 'vertical' | 'horizontal' ]
rotation_mode: unknown
size or fontsize: [ size in points | 'xx-small' | 'x-small' | 'small' | 'medium' | 'large' | 'x-large' | 'xx-large' ]
snap: unknown
stretch or fontstretch: [ a numeric value in range 0-1000 | 'ultra-condensed' | 'extra-condensed' | 'condensed' | 'semi-condensed' | 'normal' | 'semi-expanded' | 'expanded' | 'extra-expanded' | 'ultra-expanded' ]
style or fontstyle: [ 'normal' | 'italic' | 'oblique']
text: string or anything printable with '%s' conversion.
transform: :class:`~matplotlib.transforms.Transform` instance
url: a url string
variant or fontvariant: [ 'normal' | 'small-caps' ]
verticalalignment or va: [ 'center' | 'top' | 'bottom' | 'baseline' ]
visible: [True | False]
weight or fontweight: [ a numeric value in range 0-1000 | 'ultralight' | 'light' | 'normal' | 'regular' | 'book' | 'medium' | 'roman' | 'semibold' | 'demibold' | 'demi' | 'bold' | 'heavy' | 'extra bold' | 'black' ]
x: float
y: float
zorder: any number
.. plot:: mpl_examples/pylab_examples/annotation_demo2.py
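2. To determine the length of the x-axis and the height of the y-axis, we need to know the number of leaf nodes and the number of levels (depth) of the tree. The two functions below, getNumLeafs() and getTreeDepth(), compute them: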
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]  # root of the tree, i.e. the first splitting feature
    secondDict = myTree[firstStr]
    for key in secondDict.keys():  # each key is one value of the current root feature
        if type(secondDict[key]).__name__ == 'dict':  # secondDict[key] is a subtree (its root)
            numLeafs += getNumLeafs(secondDict[key])  # subtree is a dict: count its leaves recursively
        else:
            numLeafs += 1  # subtree is a leaf: add one to numLeafs
    return numLeafs
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():  # loop over all subtrees of the current root
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])  # subtree is a dict: compute its depth recursively
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
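For comparison, the same two recursions can be written more compactly with isinstance() and generator expressions. This is a sketch, not the author's code: count_leaves and tree_depth are hypothetical names, and like the original it assumes every internal node is a single-key dict whose value maps feature values to subtrees or leaf labels (and, unlike getTreeDepth(), it assumes each node has at least one branch, since max() of an empty sequence raises). Taking the root with next(iter(tree)) also sidesteps myTree.keys()[0], which only works in Python 2 because dict.keys() no longer returns a list in Python 3.

```python
def count_leaves(tree):
    """Count the leaf labels of a nested-dict decision tree (hypothetical helper)."""
    root = next(iter(tree))  # the dict's single key is the splitting feature
    return sum(count_leaves(child) if isinstance(child, dict) else 1
               for child in tree[root].values())


def tree_depth(tree):
    """Number of decision levels, counted the same way as getTreeDepth()."""
    root = next(iter(tree))
    return max(1 + tree_depth(child) if isinstance(child, dict) else 1
               for child in tree[root].values())


# the two sample trees stored by retrieveTree()
tree0 = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
tree1 = {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
print(count_leaves(tree0), tree_depth(tree0))  # 3 2
print(count_leaves(tree1), tree_depth(tree1))  # 4 3
```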
3. With plotNode(), getNumLeafs() and getTreeDepth() in place, we can start building the plotTree() function that draws the decision tree.
First, build a function retrieveTree() that stores tree data, to be used later for testing the code:
def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                   ]  # dicts nested in a list; note the list holds two dicts, i.e. two trees
    return listOfTrees[i]
Build a function plotMidText() that fills in text between a parent node and its child:
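def plotMidText(cntrPt, parentPt, txtString):
    # midpoint between the child node (cntrPt) and its parent (parentPt)
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    # call text() on the Axes object; text() uses data coordinates, which match
    # the axes fractions used by plotNode() because the empty axes keep their default 0-1 limits
    createPlot.ax1.text(xMid, yMid, txtString)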
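The midpoint formula in plotMidText() simplifies to the ordinary average of the two endpoints. Writing the parent as $(x_p, y_p)$ and the child (cntrPt) as $(x_c, y_c)$:

```latex
x_{\text{mid}} = \frac{x_p - x_c}{2} + x_c = \frac{x_p + x_c}{2}, \qquad
y_{\text{mid}} = \frac{y_p - y_c}{2} + y_c = \frac{y_p + y_c}{2}
```

So the branch label (the feature value, e.g. 0 or 1) is placed halfway along the arrow between the two nodes.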
Build the plotting function plotTree():
I referred to an expert's blog post, which made the details clear to me:
https://www.cnblogs.com/fantasy01/p/4595902.html
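def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # number of leaves in the current subtree
    depth = getTreeDepth(myTree)  # depth of the current subtree (computed but not used below)
    firstStr = list(myTree.keys())[0]
    # Position the current node from the number of leaves beneath it (the core step):
    # its leaves will occupy the slots xOff + 1/totalW, ..., xOff + numLeafs/totalW,
    # so their midpoint is xOff + (1 + numLeafs) / 2 / totalW
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    # plotTree.xOff: x coordinate of the most recently drawn leaf (not the current node yet;
    #                note its initial value). Understanding this is the key.
    # plotTree.yOff: y coordinate of the current drawing depth
    # plotTree.totalW: number of leaves in the whole tree
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  # the current node and its arrow are now drawn
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # move down one level for the children
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # not a leaf: draw the subtree recursively
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW  # advance to the next leaf slot
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)  # draw the leaf and its arrow
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # this level is finished: move back up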
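To see the coordinates this placement rule produces without drawing anything, it can be replayed in plain Python. This is a sketch under stated assumptions: layout_tree is a hypothetical helper that keeps the xOff/yOff state in a local dict instead of function attributes, starts from the same initial values that createPlot() sets, and returns (label, (x, y)) pairs in the order plotTree() would draw them.

```python
def layout_tree(tree):
    """Return (label, (x, y)) for every node, in plotTree() drawing order (hypothetical helper)."""
    def num_leaves(t):
        root = next(iter(t))
        return sum(num_leaves(c) if isinstance(c, dict) else 1 for c in t[root].values())

    def depth(t):
        root = next(iter(t))
        return max(1 + depth(c) if isinstance(c, dict) else 1 for c in t[root].values())

    total_w = float(num_leaves(tree))
    total_d = float(depth(tree))
    # same initial values as createPlot(): half a slot left of 0, top of the axes
    state = {"x_off": -0.5 / total_w, "y_off": 1.0}
    nodes = []

    def place(t):
        root = next(iter(t))
        # centre the decision node above the block of leaf slots it owns
        cntr = (state["x_off"] + (1.0 + num_leaves(t)) / 2.0 / total_w, state["y_off"])
        nodes.append((root, cntr))
        state["y_off"] -= 1.0 / total_d  # children are drawn one level lower
        for child in t[root].values():
            if isinstance(child, dict):
                place(child)
            else:
                state["x_off"] += 1.0 / total_w  # next free leaf slot
                nodes.append((child, (state["x_off"], state["y_off"])))
        state["y_off"] += 1.0 / total_d  # back up once this level is done

    place(tree)
    return nodes


tree0 = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
nodes = layout_tree(tree0)
for label, (x, y) in nodes:
    print(f"{label}: ({x:.3f}, {y:.3f})")
```

For the first sample tree the root lands at (0.5, 1.0), the same point createPlot() passes as parentPt, so the root's arrow has zero length. The leaves fall at x = 1/6, 1/2 and 5/6, and flippers sits at x = 2/3, halfway between its two leaves.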
Build the main function createPlot():
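def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')  # create a figure with a white background
    fig.clf()  # clear the figure
    axprops = dict(xticks=[], yticks=[])  # axis properties that hide the ticks; passed to subplot() below
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # on the ** syntax for passing keyword arguments: https://www.cnblogs.com/empty16/p/6229538.html
    plotTree.totalW = float(getNumLeafs(inTree))  # used as a global variable
    plotTree.totalD = float(getTreeDepth(inTree))  # used as a global variable
    # initial values of the globals: starting xOff half a slot to the left
    # puts the first leaf at 0.5/totalW, the centre of the first slot
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()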