决策树的构建、展示与决策

1. 概述

上一篇日志中,我们介绍了两个决策树构建算法 – ID3、C4.5:
决策树的构建算法 – ID3 与 C4.5 算法
本篇日志我们来看看如何使用这两个算法以及其他工具构建和展示我们的决策树。

2. 使用 C4.5 构建决策树

有了上一篇日志中,我们介绍的 ID3 与 C4.5 算法,递归进行计算,选出每一层当前的最佳特征以及最佳特征对应的最佳划分特征值,我们就可以构建出完整的决策树了:

# 此处有图片

流程图非常清晰,上图的基本思想是,对于数值型特征,我们只分为左右两分支,以防止子树过多,同时也避免多种分发造成的系统复杂度过高,而对于字符串描述性特征,我们按照特征取值个数来进行子树划分,因为通常来说,数值型特征取值会非常多,而字符串描述性特征则不会。

2.1. python 代码实现

# -*- coding: UTF-8 -*-
# {{{
import operator
from math import log


def createDataSet():
    """
    创建数据集

    :return: 数据集与特征集
    """
    dataSet = [[706, 'hot', 'sunny', 'high', 'false', 'no'],
               [706, 'hot', 'sunny', 'high', 'true', 'no'],
               [706, 'hot', 'overcast', 'high', 'false', 'yes'],
               [709, 'cool', 'rain', 'normal', 'false', 'yes'],
               [710, 'cool', 'overcast', 'normal', 'true', 'yes'],
               [712, 'mild', 'sunny', 'high', 'false', 'no'],
               [714, 'cool', 'sunny', 'normal', 'false', 'yes'],
               [715, 'mild', 'rain', 'normal', 'false', 'yes'],
               [720, 'mild', 'sunny', 'normal', 'true', 'yes'],
               [721, 'mild', 'overcast', 'high', 'true', 'yes'],
               [722, 'hot', 'overcast', 'normal', 'false', 'yes'],
               [723, 'mild', 'sunny', 'high', 'true', 'no'],
               [726, 'cool', 'sunny', 'normal', 'true', 'no'],
               [730, 'mild', 'sunny', 'high', 'false', 'yes']]
    labels = ['日期', '气候', '天气', '气温', '寒冷']
    return dataSet, labels


def classCount(dataSet):
    """
    获取每个特征出现的次数

    :param dataSet: 数据集
    :return:
    """

    labelCount = {}
    for one in dataSet:
        if one[-1] not in labelCount.keys():
            labelCount[one[-1]] = 0
        labelCount[one[-1]] += 1
    return labelCount


def calcShannonEntropy(dataSet):
    """
    计算系统信息熵

    :param dataSet: 数据集
    :return:
    """

    labelCount = classCount(dataSet)
    numEntries = len(dataSet)
    Entropy = 0.0
    for i in labelCount:
        prob = float(labelCount[i]) / numEntries
        Entropy -= prob * log(prob, 2)
    return Entropy


def majorityClass(dataSet):
    """
    找到对应结果最多的特征

    :param dataSet: 数据集
    :return:
    """
    labelCount = classCount(dataSet)
    sortedLabelCount = sorted(labelCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedLabelCount[0][0]


def splitDataSet(dataSet, i, value):
    """
    非数值型特征划分
    将 dataset 以第 i 个特征值为 value 作为基准划分为多个部分

    :param dataSet: 数据集
    :param i: 特征索引
    :param value: 划分基准值
    :return:
    """

    subDataSet = []
    for one in dataSet:
        if one[i] == value:
            reduceData = one[:i]
            reduceData.extend(one[i + 1:])
            subDataSet.append(reduceData)
    return subDataSet


def splitContinuousDataSet(dataSet, i, value, direction):
    """
    数值型特征划分
    将 dataset 以第 i 个特征值为 value 作为基准划分为多个部分

    :param dataSet: 数据集
    :param i: 特征索引
    :param value: 划分基准值
    :param direction: 0. 左侧, 1. 右侧
    :return:
    """

    subDataSet = []
    for one in dataSet:
        if direction == 0:
            if one[i] > value:
                reduceData = one[:i]
                reduceData.extend(one[i + 1:])
                subDataSet.append(reduceData)
        if direction == 1:
            if one[i] <= value:
                reduceData = one[:i]
                reduceData.extend(one[i + 1:])
                subDataSet.append(reduceData)
    return subDataSet


def chooseBestFeat(dataSet, labels):
    """
    获取最佳特征与特征对应的最佳划分值

    :param dataSet: 数据集
    :param labels: 特征集
    :return:
    """

    global bestSplit
    """ 计算划分前系统的信息熵 """
    baseEntropy = calcShannonEntropy(dataSet)
    bestFeat = 0
    baseGainRatio = -1
    numFeats = len(dataSet[0]) - 1
    bestSplitDic = {}

    """ 遍历每个特征 """
    for i in range(numFeats):
        """ 获取该特征所有值 """
        featVals = [example[i] for example in dataSet]
        uniVals = sorted(set(featVals))
        if type(featVals[0]).__name__ == 'float' or type(featVals[0]).__name__ == 'int':

            """ 用于区分的坐标值 """
            splitList = []
            for j in range(len(uniVals) - 1):
                splitList.append((uniVals[j] + uniVals[j + 1]) / 2.0)

            """ 计算信息增益比,找到最佳划分属性与划分阈值 """
            for j in range(len(splitList)):

                """ 该划分情况下熵值 """
                newEntropy = 0.0
                splitInfo = 0.0
                value = splitList[j]

                """ 划分出左右两侧数据集 """
                subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0)
                subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1)

                """ 计算划分后系统信息熵 """
                prob0 = float(len(subDataSet0)) / len(dataSet)
                newEntropy -= prob0 * calcShannonEntropy(subDataSet0)
                prob1 = float(len(subDataSet1)) / len(dataSet)
                newEntropy -= prob1 * calcShannonEntropy(subDataSet1)

                """ 获取惩罚参数 """
                splitInfo -= prob0 * log(prob0, 2)
                splitInfo -= prob1 * log(prob1, 2)

                """ 计算信息增益比 """
                gainRatio = float(baseEntropy - newEntropy) / splitInfo

                if gainRatio > baseGainRatio:
                    baseGainRatio = gainRatio
                    bestSplit = j
                    bestFeat = i

            bestSplitDic[labels[i]] = splitList[bestSplit]
        else:
            splitInfo = 0.0
            newEntropy = 0.0
            for value in uniVals:
                """ 划分数据集 """
                subDataSet = splitDataSet(dataSet, i, value)

                """ 计算划分后系统信息熵 """
                prob = float(len(subDataSet)) / len(dataSet)
                newEntropy -= prob * calcShannonEntropy(subDataSet)

                """ 获取惩罚参数 """
                splitInfo -= prob * log(prob, 2)

            """ 计算信息增益比 """
            gainRatio = float(baseEntropy - newEntropy) / splitInfo
            if gainRatio > baseGainRatio:
                bestFeat = i
                baseGainRatio = gainRatio

    bestFeatValue = None
    if type(dataSet[0][bestFeat]).__name__ == 'float' or type(dataSet[0][bestFeat]).__name__ == 'int':
        bestFeatValue = bestSplitDic[labels[bestFeat]]
    if type(dataSet[0][bestFeat]).__name__ == 'str':
        bestFeatValue = labels[bestFeat]
    return bestFeat, bestFeatValue


def createTree(dataSet, labels):
    """
    递归创建决策树

    :param dataSet: 数据集
    :param labels: 特征指标集
    :return: 决策树字典结构
    """
    classList = [example[-1] for example in dataSet]

    if len(set(classList)) == 1:
        return classList[0]

    if len(dataSet[0]) == 1:
        return majorityClass(dataSet)

    """ 找到当前的最佳划分属性与划分阈值 """
    bestFeat, bestFeatLabel = chooseBestFeat(dataSet, labels)

    myTree = {labels[bestFeat]: {}}
    subLabels = labels[:bestFeat]
    subLabels.extend(labels[bestFeat + 1:])

    if type(dataSet[0][bestFeat]).__name__ == 'str':
        featVals = [example[bestFeat] for example in dataSet]
        uniqueVals = set(featVals)

        """ 递归创建左右子树 """
        for value in uniqueVals:
            """ 获取去除该特征数据集 """
            reduceDataSet = splitDataSet(dataSet, bestFeat, value)
            myTree[labels[bestFeat]][value] = createTree(reduceDataSet, subLabels)

    if type(dataSet[0][bestFeat]).__name__ == 'int' or type(dataSet[0][bestFeat]).__name__ == 'float':
        value = bestFeatLabel

        """ 划分数据集 """
        greaterDataSet = splitContinuousDataSet(dataSet, bestFeat, value, 0)
        smallerDataSet = splitContinuousDataSet(dataSet, bestFeat, value, 1)

        """ 递归创建左右子树 """
        myTree[labels[bestFeat]]['>' + str(value)] = createTree(greaterDataSet, subLabels)
        myTree[labels[bestFeat]]['<=' + str(value)] = createTree(smallerDataSet, subLabels)
    return myTree


if __name__ == '__main__':
    dataSet, labels = createDataSet()
    print(createTree(dataSet, labels))
    #}}}

返回了:

{
  '日期': {
    '>728.0': 'yes',
    '<=728.0': {
      '寒冷': {
        'false': {
          '气温': {
            'high': {
              '气候': {
                'hot': {
                  '天气': {
                    'sunny': 'no',
                    'overcast': 'yes'
                  }
                },
                'mild': 'no'
              }
            },
            'normal': 'yes'
          }
        },
        'true': {
          '气温': {
            'high': {
              '气候': {
                'hot': 'no',
                'mild': {
                  '天气': {
                    'sunny': 'no',
                    'overcast': 'yes'
                  }
                }
              }
            },
            'normal': {
              '气候': {
                'mild': 'yes',
                'cool': {
                  '天气': {
                    'sunny': 'no',
                    'overcast': 'yes'
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}

3. 决策树的可视化

上面的 json 结果看上去非常不清楚,我们可不可以画出决策树的树结构呢?
我们可以利用 matplotlib 模块来实现树结构的绘制:

# -*- coding: UTF-8 -*-
# {{{
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties


def getNumLeafs(myTree):
    """
    获取决策树叶子结点的数目

    :param myTree: 决策树
    :return: 决策树的叶子结点的数目
    """
    numLeafs = 0  # 初始化叶子
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]  # 获取下一组字典
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # 测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """
    获取决策树的层数

    :param myTree: 决策树
    :return: 决策树的层数
    """
    maxDepth = 0  # 初始化决策树深度
    firstStr = next(iter(
        myTree))  # python3中myTree.keys()返回的是dict_keys,不在是list,所以不能使用myTree.keys()[0]的方法获取结点属性,可以使用list(myTree.keys())[0]
    secondDict = myTree[firstStr]  # 获取下一个字典
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # 测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth  # 更新层数
    return maxDepth


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """
    绘制节点

    :param nodeTxt: 结点名
    :param centerPt: 文本位置
    :param parentPt: 标注的箭头位置
    :param nodeType: 结点格式
    :return:
    """
    arrow_args = dict(arrowstyle="<-")  # 定义箭头格式
    font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14)  # 设置中文字体
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',  # 绘制结点
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args, FontProperties=font)


def plotMidText(cntrPt, parentPt, txtString):
    """
    标注有向边属性值

    :param cntrPt: 当前节点
    :param parentPt: 父节点
    :param txtString: 标注内容
    :return:
    """
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]  # 计算标注位置
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    """
    绘制决策树

    :param myTree: 决策数字典
    :param parentPt: 父节点
    :param nodeTxt: 节点名
    :return:
    """
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # 设置结点格式
    leafNode = dict(boxstyle="round4", fc="0.8")  # 设置叶结点格式
    numLeafs = getNumLeafs(myTree)  # 获取决策树叶结点数目,决定了树的宽度
    firstStr = next(iter(myTree))  # 下个字典
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)  # 中心位置
    plotMidText(cntrPt, parentPt, nodeTxt)  # 标注有向边属性值
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  # 绘制结点
    secondDict = myTree[firstStr]  # 下一个字典,也就是继续绘制子结点
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # y偏移
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # 测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
            plotTree(secondDict[key], cntrPt, str(key))  # 不是叶结点,递归调用继续绘制
        else:  # 如果是叶结点,绘制叶结点,并标注有向边属性值
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    """
    创建绘制面板

    :param inTree: 决策树字典
    :return:
    """
    fig = plt.figure(1, facecolor='white')  # 创建 fig
    fig.clf()  # 清空 fig
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # 去掉 x、y 轴
    plotTree.totalW = float(getNumLeafs(inTree))  # 获取决策树叶结点数目
    plotTree.totalD = float(getTreeDepth(inTree))  # 获取决策树层数
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0  # x偏移
    plotTree(inTree, (0.5, 1.0), '')  # 绘制决策树
    plt.show()  # 显示绘制结果


if __name__ == '__main__':
    myTree = {'日期':
                  {'>728.0': 'yes',
                   '<=728.0':
                       {'寒冷':
                            {'false':
                                 {'气温':
                                      {'high':
                                           {'气候':
                                                {'hot':
                                                     {'天气':
                                                          {'sunny': 'no',
                                                           'overcast': 'yes'
                                                           }
                                                      },
                                                 'mild': 'no'}
                                            },
                                       'normal': 'yes'}
                                  },
                             'true':
                                 {'气温':
                                      {'high':
                                           {'气候':
                                                {'hot': 'no',
                                                 'mild':
                                                     {'天气':
                                                          {'sunny': 'no',
                                                           'overcast': 'yes'}
                                                      }
                                                 }                                            
                                            },
                                       'normal':
                                           {'气候':
                                                {'mild': 'yes',
                                                 'cool':
                                                     {'天气':
                                                          {'sunny': 'no',
                                                           'overcast': 'yes'
                                                           }
                                                      }
                                                 }
                                            }
                                       }
                                  }
                             }
                        }
                   }
              }
    print(myTree)
    createPlot(myTree)
    # }}}

后面博主会专门写一篇博客全面介绍 matplotlib 的使用。
程序执行最终打印出了:

在这里插入图片描述

4. 预测

既然构建好了我们的决策树,接下来我们就可以预测决策了:

# -*- coding: UTF-8 -*-
# {{{
import re


def predict(inputTree, featLabels, testVec):
    firstStr = next(iter(inputTree)) #获取决策树结点
    secondDict = inputTree[firstStr] #下一个字典
    featIndex = featLabels.index(firstStr)
    classLabel = None
    for key in secondDict.keys():
        into = (testVec[featIndex] == key)
        if type(testVec[featIndex]).__name__ == 'float' or type(testVec[featIndex]).__name__ == 'int':
            i = float(re.findall("\d+", key)[0])
            if (key[:1] == '>' and testVec[featIndex] > i) or (key[:2] == '<=' and testVec[featIndex] <= i):
                into = True
        if into:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = predict(secondDict[key], featLabels, testVec)
            else: classLabel = secondDict[key]
    return classLabel


if __name__ == '__main__':
    myTree = {'日期':
                  {'>728.0': 'yes',
                   '<=728.0':
                       {'寒冷':
                            {'false':
                                 {'气温':
                                      {'high':
                                           {'气候':
                                                {'hot':
                                                     {'天气':
                                                          {'sunny': 'no',
                                                           'overcast': 'yes'
                                                           }
                                                      },
                                                 'mild': 'no'}
                                            },
                                       'normal': 'yes'}
                                  },
                             'true':
                                 {'气温':
                                      {'high':
                                           {'气候':
                                                {'hot': 'no',
                                                 'mild':
                                                     {'天气':
                                                          {'sunny': 'no',
                                                           'overcast': 'yes'}
                                                      }
                                                 }                                            
                                            },
                                       'normal':
                                           {'气候':
                                                {'mild': 'yes',
                                                 'cool':
                                                     {'天气':
                                                          {'sunny': 'no',
                                                           'overcast': 'yes'
                                                           }
                                                      }
                                                 }
                                            }
                                       }
                                  }
                             }
                        }
                   }
              }

    testVec = [810, 'false']  # 测试数据
    result = predict(myTree, ['日期', '寒冷', '气温', '气候', '天气'], testVec)
    if result == 'yes':
        print('打高尔夫')
    if result == 'no':
        print('不打高尔夫')
    # }}}

打印出了:

打高尔夫

欢迎关注微信公众号

在这里插入图片描述

参考资料

Peter Harrington 《机器学习实战》。
https://en.wikipedia.org/wiki/ID3_algorithm。
https://en.wikipedia.org/wiki/C4.5_algorithm。
https://blog.csdn.net/c406495762/article/details/76262487。

猜你喜欢

转载自blog.csdn.net/DILIGENT203/article/details/83689594