决策树——机器学习实战完整版(python 3)

import matplotlib.pyplot as plt
# boxstyle是文本框类型 fc是边框粗细 sawtooth是锯齿形
'''xy是终点坐标
xytext是起点坐标
可能疑问:为什么说是终点,但是却是箭头从这出发的?
解答:arrowstyle="<-" 看到没有,这是个反向的箭头'''
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")

#createPlot 主函数,调用即可画出决策树,其中调用登了剩下的所有的函数,inTree的形式必须为嵌套的决策树

def createPlot():
    fig=plt.figure(1,facecolor='white') # 新建一个画布,背景设置为白色的
    fig.clf()# 将画图清空
    createPlot.ax1=plt.subplot(111,frameon=False)# 设置一个多图展示,但是设置多图只有一个,
    # 但是设置参数是111,构建了一个1*1的模块,并操作对象指向第一个图。
    plotNode('decision', (0.5,0.1),(0.1,0.5),decisionNode)
    plotNode('leaf', (0.8,0.5),(0.3,0.7),leafNode)
    plt.show()

def plotNode(nodeTxt,centerPt,parentPt,nodeType):#plotNode函数有nodeTxt,centerPt, parentPt, nodeType这四个参数。
                                                 # nodeTxt是注释的文本信息。centerPt表示那个节点框的位置。
                                                 #  parentPt表示那个箭头的起始位置(终点坐标)。nodeType表示的是节点的类型,
                                                #  也就会用我们之前定义的全局变量。#xytext是起点坐标  #va="center",ha="center"是坐标的水平中心和垂直中心
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\
                            xytext=centerPt,textcoords='axes fraction',\
                            va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)#annotate是注释的意思,
                                                                        # 也就是作为原来那个框的注释,也是添加一些新的东西
                                                 #arrowprops=arrow_args是结点的颜色
def getNumLeafs(myTree):
    numLeafs=0
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    #firstStr=myTree.keys()[0]# 找到输入的第一个元素,第一个关键词为划分数据集类别的标签
    secondDict=myTree[firstStr]# mytree经过第一个特征值分类后的字典
    for key in secondDict.keys():#测试数据是否为字典形式
        if type(secondDict[key]).__name__=='dict':#  type(secondDict[key]).__name__输出的是括号里面的变量的类型,即判断secondDict[key]对应的内容是否为字典类型
            numLeafs+=getNumLeafs(secondDict[key])
        else:  numLeafs+=1
    return numLeafs
def getTreeDepth(myTree):
    maxDepth=0
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    #firstStr=myTree.keys()[0]# 找到输入的第一个元素,第一个关键词为划分数据集类别的标签
    secondDict=myTree[firstStr]# mytree经过第一个特征值分类后的字典
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
        else: thisDepth=1
        if thisDepth>maxDepth: maxDepth=thisDepth
    return maxDepth
def retrieveTree(i):
    listOfTrees=[{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},\
                 {'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]
    return listOfTrees[i]
def plotMidText(cntrPt,parentPt,txtString):#在坐标点cntrPt和parentPt连接线上的中点,显示文本txtString   #parentPt表示那个箭头的起始位置(终点坐标),cntrPt叶节点的位置,箭头的终点

    xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]#x轴坐标
    yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]#y轴坐标
    createPlot.ax1.text(xMid,yMid,txtString)#在(xMid, yMid)处显示txtString
def plotTree(myTree,parentPt,nodeTxt):  # nodeTxt是注释的文本信息
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)#cntrPt用来记录当前要画的树的树根的结点位置
    # plotTree.xOff和plotTree.yOff是用来追踪已经绘制的节点位置,plotTree.totalW为这个数的宽度,叶节点数
    #cntrPt用来记录当前要画的树的树根的结点位置在plotTree函数中,它是这样计算的
    # cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    # numLeafs记录当前的树中叶子结点个数。我们希望树根在这些所有叶子节点的中间
    # plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW 这里的
    # 1.0 + numLeafs 需要拆开来理解,也就是
    # plotTree.xOff + float(numLeafs) / 2.0 / plotTree.totalW + 1.0 / 2.0 / plotTree.totalW
    # plotTree.xOff + 1 / 2 * float(numLeafs) / plotTree.totalW + 0.5 / plotTree.totalW
    # 因为xOff的初始值是 - 0.5 / plotTree.totalW ,是往左偏了0.5 / plotTree.tatalW的,
    # 这里正好加回去。这样cntrPt记录的x坐标正好是所有叶子结点的中心点'''
    plotMidText(cntrPt,parentPt,nodeTxt)#显示节点
    plotNode(firstStr,cntrPt,parentPt,decisionNode)#firstStr为需要显示的文本,cntrPt为文本的中心点,
                                                   # parentPt为箭头指向文本的起始点,decisionNode为文本属性
    secondDict=myTree[firstStr]#子树
    plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD#totalD是这个数的深度,深度移下一层,初始值为1
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW#x坐标平移一个单位
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)#画叶节点
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))#显示箭头文本
    plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD#下移一层深度

def createPlot(inTree):#是主函数,调用了plotTree,plotTree又调用了其他的函数
    fig=plt.figure(1,facecolor='white')#创建一个画布,背景为白色
    fig.clf()#画布清空
    axprops=dict(xticks=[],ytichs=[])#定义横纵坐标轴,无内容
    createPlot.ax1 = plt.subplot(111, frameon=False)
    #createPlot.ax1=plt.subplot(111,frameon=False,**axprops)#去掉x、y轴,xticks=[],ytichs=[]无内容,          #**表示此参数是字典参数
    # ax1是函数createPlot的一个属性,这个可以在函数里面定义也可以在函数定义后加入也可以
    # createPlot.ax1 = plt.subplot(111, frameon = False, **axprops) #frameon表示是否绘制坐标轴矩形,无坐标轴,111代表1X1个图,第一个
    plotTree.totalW=float(getNumLeafs(inTree))
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.xOff=-0.5/plotTree.totalW#如果叶子结点的坐标是 1/totalW , 2/totalW, 3/totalW, …, 1 的话,
                                 # 就正好在宽度的最右边,为了让坐标在宽度的中间,需要减去0.5 / totalW 。
                                 # 所以createPlot函数中,初始化 plotTree.xOff 的值为-0.5/plotTree.totalW。
                                # 这样每次 xOff + 1/totalW ,正好是下1个结点的准确位置
    plotTree.yOff=1.0               #yOff的初始值为1,每向下递归一次,这个值减去 1 / totalD
    plotTree(inTree,(0.5,1.0),'')
    plt.show()

猜你喜欢

转载自blog.csdn.net/youhuakongzhi/article/details/85479093