前面记录了特征值的选取,现在我们就来说一下剪枝。
决策树的剪枝
在决策树创建时,由于数据中的噪声和离群点,许多分枝反映的是训练数据中的异常,剪枝方法处理这种过分拟合数据的问题。
有常用的两种剪枝方法:先剪枝和后剪枝。
先剪枝:通过提前停止树的构建(例如,通过决定在给定的结点不再分裂或划分训练元组的子集)而对树"剪枝"。一旦停止,结点就成为树叶。
后剪枝:由"完全生长"的树剪去子树,通过删除结点的分枝并使用树叶替换它而剪掉给定节点上的子树。该树叶的类标号用子树中最频繁的类标记。
决策树的剪枝往往通过极小化决策树整体的损失函数或代价函数来实现。设树的叶结点的个数为,是树的叶结点,该叶结点有个样本点,其中类的样本点有个,,为叶结点上的经验熵,则决策树的损失函数可以定义为:
(损失函数=拟合度+a*模型复杂度)
其中表示模型对训练数据的预测误差,即模型与训练数据的拟合程度。为经验熵;为叶子节点的个数,用来表示模型的复杂度。
树中每个节点t下的树如果被减下去,则整体损失函数减少的程度为: 。
为以t为根结点的子树,为t为单节点树时的预测误差,为以t为根的子树的预测误差。
在算法CART中,在整体树中剪去最小的得到子树,在从中剪去最小的得到子树,如此递归,直到只有根节点,这样就按找到了一个按顺序的子树序列。最后通过交叉验证得到最优的一个子树,作为真正剪枝后的树。
实例代码:
# -*- coding:utf-8 -*-
from math import log
my_data=[['slashdot','USA','yes',18,'None'],
['google','France','yes',23,'Premium'],
['digg','USA','yes',24,'Basic'],
['kiwitobes','France','yes',23,'Basic'],
['google','UK','no',21,'Premium'],
['(direct)','New Zealand','no',12,'None'],
['(direct)','UK','no',21,'Basic'],
['google','USA','no',24,'Premium'],
['slashdot','France','yes',19,'None'],
['digg','USA','no',18,'None'],
['google','UK','no',18,'None'],
['kiwitobes','UK','no',19,'None'],
['digg','New Zealand','yes',12,'Basic'],
['slashdot','UK','no',21,'None'],
['google','UK','yes',18,'Basic'],
['kiwitobes','France','yes',19,'Basic']]
class decisionnode:
def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):
self.col=col #待检验的判断条件所对应的列索引值
self.value=value #对应与为了使结果为true,当前列必须匹配的值
self.results=results #保存的是针对当前分支的结果,它是一个字典,除叶节点以外,在其他节点上该值都是None
self.tb=tb #对应于结果为true时,当前节点的子树上的节点
self.fb=fb #对应于结果为false时,当前节点的子树上的节点
#CART算法
#在某一列上对数据集合进行拆分,能够处理数值型数据或名词性数据
def divideset(rows,column,value):
split_function=None
if isinstance(value,int) or isinstance(value,float):
split_function=lambda row:row[column]>=value #针对连续参数
else:
split_function=lambda row:row[column] ==value #针对离散参数
#将函数拆分为两个集合,并返回
set1=[row for row in rows if split_function(row)]
set2=[row for row in rows if not split_function(row)]
return (set1,set2)
#对各种可能的结果进行计数(每一行数据的最后一列记录了这一计数结果)
def uniquecounts(rows):
results={}
for row in rows:
#计数结果在最后一行
r=row[len(row)-1]
if r not in results:
results[r]=0
results[r]+=1
return results
#Gini不纯度计算
#随机放置的数据项出现于错误分类中的概率
# def giniimpurity(rows):
# total=len(rows)
# counts=uniquecounts(rows)
# imp=0
# for k1 in counts:
# p1=float(counts[k1])/total
# for k2 in counts:
# if k1==k2:
# continue
# p2=float(counts[k2])/total
# imp+= p1*p2
# return imp
def gini(rows):
total=len(rows)
counts=uniquecounts(rows)
imp=0
for k1 in counts:
p1=float(counts[k1])/total
imp+=p1*p1
return 1-imp
#方差
def variance(rows):
if len(rows)==0:
return 0
data=[float(row[len(row)-1]) for row in rows]
mean=sum(data)/len(data)
variance=sum([(d-mean)**2 for d in data])/len(data)
return variance
#熵是遍历所有可能的结果之后所得到的p()log(p(x))之和
def entropy(rows):
log2=lambda x:log(x)/log(2)
results=uniquecounts(rows)
ent=0.0
for r in results.keys():
p=float(results[r])/len(rows)
ent=ent-p*log2(p)
return ent
def buildtree(rows,coref=entropy):
if len(rows)==0:
return decisionnode()
current_score=coref(rows)
#定义一些变量以记录最佳拆分条件
best_gain=0.0
best_criteria=None
best_sets=None
column_count=len(rows[0])-1 #除去最后一列,因为最后一列为结果
for col in range(0,column_count):
#在当前列中生成一个由不停值构成的序列
column_values={}
for row in rows:
column_values[row[col]]=1
#接下来根据这一列中的每个值,尝试对数据集进行划分
for value in column_values.keys():
(set1,set2)=divideset(rows,col,value)
#信息增益
p=float(len(set1))/len(rows)
gain=current_score-p*coref(set1)-(1-p)*coref(set2)
if gain>best_gain and len(set1)>0 and len(set2)>0:
best_gain=gain
best_criteria=(col,value)
best_sets=(set1,set2)
#创建子分支
if best_gain>0:
trueBranch=buildtree(best_sets[0])
falseBranch=buildtree(best_sets[1])
return decisionnode(col=best_criteria[0],value=best_criteria[1],tb=trueBranch,fb=falseBranch)
else:
return decisionnode(results=uniquecounts(rows))
#对新的观测数据进行分类
def classfiy(observation,tree):
if tree.results==None:
return tree.results
else:
v=observation[tree.col]
branch=None
if isinstance(v,int) or isinstance(v,float):
if v>=tree.value:
branch=tree.tb
else:
branch=tree.fb
else:
if v==tree.value:
branch=tree.tb
else:
branch=tree.fb
return classfiy(observation,branch)
def mdclassfiy(observation,tree):
if tree.results!=None:
return tree.results
else:
v=observation[tree.col]
if v==None:
tr,fr=mdclassfiy(observation,tree.tb),mdclassfiy(observation,tree.fb)
tcount=sum(tr.values())
fcount=sum(fr.values())
tw=float(tcount)/(tcount+fcount)
fw=float(fcount)/(tcount+fcount)
result={}
for k,v in tr.itmes():
result[k]=v*tw
for k,v in fr.items():
if k not in result:
result[k]=0
result[k]+=v*fw
return result
else:
if isinstance(v,int) or isinstance(v,float):
if v>=tree.values:
branch=tree.tb
else:
branch=tree.fb
else:
if v==tree.value:
branch=tree.tb
else:
branch=tree.fb
return mdclassfiy(observation,branch)
#决策树的剪枝
def prune(tree,mingain):
#如果分支不是叶节点,则对其进行剪枝操作
if tree.tb.results==None:
prune(tree.tb,mingain)
if tree.fb.results==None:
prune(tree.fb,mingain)
#如果两个子分支都是叶节点,则判断他们是否要合并
if tree.tb.results!=None and tree.fb.results!=None:
#构造合并后的数据集
fb,tb=[],[]
for v,c in tree.tb.results.items():
tb+=[[v]]*c
for v,c in tree.fb.results.items():
fb+=[[v]]*c
#检查熵的减少情况
delta=entropy(tb+fb)-(entropy(tb)+entropy(fb)/2)
if delta<mingain:
#合并分支
tree.tb,tree.fb=None,None
tree.results=uniquecounts(tb+fb)
if __name__ == '__main__':
print(divideset(my_data,2,'yes'))
print(uniquecounts(my_data))
#print(giniimpurity(my_data))
print(gini(my_data))
print(entropy(my_data))
set1,set2=divideset(my_data,2,'yes')
print(entropy(set1))
print(gini(set1))
print(buildtree(my_data))
print(classfiy(['(direct)','USA','yes',5],buildtree(my_data)))