决策树ID3算法及其Python实现

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/chai_zheng/article/details/78226409

决策树是一个有向无环图,由节点和有向边组成,根节点代表所有的样例,内部节点表示样例的一个属性,叶节点代表一个类。我们先来看WikiPedia上给出的例子,从而对决策树有一个直观理解。


1

这个图里,我们可以看到,是否出门浪要受到几个变量的影响:天气、温度、湿度、多云这四个,是一个14行5列的数据集。根据这个数据集,我们可以得到下面的决策树。


2

最开始的根节点上,包括14个样例,分别是9个浪、5个不浪。受三种天气影响,可以分为三个分支,在sunny这一分支里共5个样例,根据humidity是否高于70%可以将这5个样例再分为2个和3个,2个全部是浪,3个全部是在家待着,因此到这,最左边的这一条线路已经完全分到不能再分,其他线路也是同理。到这里,我们可以看出来决策树是一种解释性非常好的模型,非常符合人脑的推理特性,其运算速度也很快,只需按分支划分就可以。

图很好理解,来想一下这个图怎么生成的。刚开始根据天气情况分成三个分支,天气这个属性是怎么确定的?为什么先分天气,不分温度、湿度?这是第一个问题。按天气左侧分支往下走,到第二个湿度属性,这个属性的确定也是问题,再沿左侧分支往下走,这里就是一个叶节点了,叶节点如何确定?这是第三个问题。下面,我们来解决这三个问题。

我们自然希望选的第一个划分属性具有最好的分类效果,从而提高决策树的分类能力。这是机器学习中特征工程领域的重要问题,一般称为特征选择问题,常用的方法有filter过滤法、wrapper缠绕法、embedded嵌入法,具体我不再展开说,有机会详述一下。通常,决策树ID3算法采用的是filter法中的互信息方法来选择最优属性。来看下互信息的定义:

Gain(D,a)=H(D)H(D|A)

其中, D 为数据集, H(D) 代表数据集 D 的经验熵, H(D|A) 代表了给定特征 A 后的经验条件熵。熵代表了随机变量不确定的程度,熵越大,越混乱。定义为:
H(p)=(pilog2pi)

当随机变量只取两个值0/1时,X为0-1分布,则X的熵为:
H(p)=plog2p(1p)log2(1p)

p=0 p=1 时,随机变量没有不确定性,熵为0, p=0.5 时,熵取值最大为1。

另外一项,条件熵定义为:

H(D|A)=|Dv||D|H(Dv)

Dv 代表了 D 中第 a 个属性上取值为 av 的样例个数。

互信息表征了由于特征 A 的已知,导致数据集 D 的不确定性减少的程度。那么很显然,互信息越大的特征具有更强的分类能力。所以我们可以通过计算样例所有属性的互信息值,来决定最优属性的选择。这就是著名的ID3算法所采用的准则。

容易发现的一个缺点是,这个准则对可取值数目较多的属性有所偏好。考虑极限情况,有一个属性为每一个样本都划分了一个分支,那么条件熵 H(D|A)=0 ,此时互信息最大,但显然这个属性不是最优。为了调整这个缺陷,C4.5算法在互信息的基础上多除了一个分母,从而缓解此情况。除此之外,CART决策树通过定义基尼指数来寻找最优属性,其被定义为:

Gini(D)=pkpk

直观理解,就是从数据集中任选两个样本,其类别标记不一致的概率。值越小,数据集纯度越高,与之对应的还有针对属性 A 定义的基尼系数,读者可自行查看相关文献。

有了最优属性的选择准则,我们就可以通过不断迭代来生成更多的子树。终止迭代条件是什么呢?有两个条件。像刚开始是否要出去浪那个例子一样,如果某一个节点下的所有样本都属于同一类,那么这个节点显然是一个叶节点。另外,如果在某条线路上,我已经使用了所有的属性,仍然在最后一个属性上也没有把数据完全剩成一类,那么就选其中个数最多的类作为这个叶节点的类别标记。

在实际使用决策树的时候,经常会发现,决策树在学习样本的时候过多地考虑了如何对训练数据进行正确归类,从而构建了一颗过于复杂的决策树,从而产生了过拟合现象。缓解的办法就是来主动地给大树减掉一些分支,来降低过拟合风险。

这里我们利用验证数据。即,将数据划分为训练数据和测试数据,训练数据再划分成训练数据和验证数据,在利用训练数据生成树的时候,利用验证数据来测试决策树的泛化性能,从而决定如何剪枝。具体操作有两种:

一、预剪枝。在生成树的时候,每到一个节点,我们先估计一下,有这个节点和将这个节点作为叶节点这两种情况,在验证集上的精度是否有变化,从而决定是否保留这个节点。

二、后剪枝。我们先按照正常情况完整地生成一棵树,然后从叶节点向上,逐步对非叶节点进行计算,如果将该节点对应的子树换成叶节点可以在验证集上获得更高的精度,那我们就把这个子树剪掉。

可以看到,预剪枝使得决策树的很多分支都没有展开,这显然降低了过拟合的风险和训练开销,但另一方面,尽管某些分支的当前划分会导致泛化性能的下降,但后面展开的分支可能会提升性能,这样粗暴的剪掉,可能避免了过拟合,但造成了欠拟合。后剪枝显然欠拟合的风险就比较小,但训练开销比较大。这二者各自的优缺点。

Python源码

注:本程序基于Peter Harrington的《Machine Learning in Action》做了一些改动,向先辈致敬。

# !/usr/bin/env python3
# coding=utf-8
"""
Decision Tree
Author  :Chai Zheng
Blog    :http://blog.csdn.net/chai_zheng/
Github  :https://github.com/Chai-Zheng/Machine-Learning
Email   :[email protected]
Date    :2017.10.12
Basis   :Machine Learning in Action,Peter Harrington
"""

import math
import operator

#计算数据的信息熵
def calcEntropy(data):
    numSamples = len(data)
    numClass = {}
    Entropy = 0.0
    label = [sample[-1] for sample in data]
    for i in label:
        numClass[i] = numClass.get(i,0)+1   #求不同类的数量
    for j in numClass:
        prob = float(numClass[j]/numSamples)
        Entropy = Entropy - prob*math.log(prob,2)
    return Entropy

#取出数据中第i列值为setValue的样本
def splitData(data,i,setValue):
    subData = []
    for sample in data:
        if sample[i] == setValue:
            reducedSample = sample[:i]    #删除该样本的第i列
            reducedSample.extend(sample[i+1:])
            subData.append(reducedSample)
    return subData

#选择最优属性
def slctAttribute(data):
    allEntropy = calcEntropy(data)
    numSamples = len(data)
    numAttributes = len(data[0])-1
    initMI = 0.0
    for i in range(numAttributes):
        valueList = [sample[i] for sample in data]  #拿出数据的第i列
        value = set(valueList)  #拿出这一列的所有不等值
        numEntropy = 0.0
        for j in value:
            subData = splitData(data,i,j)
            proportion = float(len(subData)/numSamples)
            Entropy = calcEntropy(subData)
            numEntropy = numEntropy + Entropy*proportion
        MI = allEntropy - numEntropy    #计算互信息
        if MI > initMI:
            initMI = MI
            slcAttribute = i
    return slcAttribute

#属性已遍历到最后一个,取该属性下样本最多的类为叶节点类别标记
def majorVote(classList):
    classCount = {}
    for i in classList:
        #第一次进入,分别把classList的不同值赋给classCount的键值
        if i not in classCount.keys():
            #构建键值对,用于对每个classList的不同元素来计数
            classCount[i] = 0
        else:
            classCount[i] += 1
    #按每个键的键值降序排列
    sortClassCount = sorted(classCount.items,key = operator.itemgetter(1),reverse = True)
    return sortClassCount[0][0]

def createTree(data,attributes):
    classList = [i[-1] for i in data]   #取data的最后一列(标签值)
    #count出classList中第一个元素的数目,如果和元素总数相等,那么说明样本全部属于某一类,此时结束迭代
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(data[0]) == 1: #遍历后只剩下一个属性,那么标记类别为样本最多的类
        return majorVote(classList)
    selectAttribute = slctAttribute(data)
    bestAttribute = attributes[selectAttribute]
    myTree = {bestAttribute:{}} #生成树,采用字典嵌套的方式记录树
    del(attributes[selectAttribute])    #删除此时的最优属性
    attributeValue = [sample[selectAttribute] for sample in data] #取出data所有样本的第selectAttribute个变量的值
    branch = set(attributeValue)    #取唯一取值,作为本节点的所有分支
    for value in branch:
        subAttributes = attributes[:]
        myTree[bestAttribute][value] = createTree(splitData(data,selectAttribute,value),subAttributes)  #迭代生成子树
    return myTree

if __name__ == '__main__':
    #西瓜数据集,根蒂=1代表蜷缩,0代表硬挺;纹理=1代表模糊,0代表清晰
    data = [[1,0,'good'],[1,0,'good'],[0,0,'bad'],[0,1,'bad'],[1,1,'bad']]
    attributes = ['根蒂','纹理']
    Tree = createTree(data,attributes)
    print(Tree)

运行结果如下:


3

本代码只是根据训练数据生成了一颗决策树。可以看到,在对西瓜进行判别的问题上,决策树刚开始选择根蒂为根节点,如果根蒂硬挺,那么直接判别西瓜不好吃;如果根蒂蜷缩,那么以纹理为子树根节点,纹理清晰则为好吃,纹理模糊则不好吃。

猜你喜欢

转载自blog.csdn.net/chai_zheng/article/details/78226409