什么是决策树算法
决策树是一种基本的分类与回归算法,模型呈树状结构。在分类任务中,它表示基于特征对实例进行分类的过程,可以认为是定义在特征空间上的、给定特征条件下类的条件概率分布。
决策树算法种类
- ID3:基于信息增益选择特征的算法
- C4.5:基于信息增益率选择特征的算法
- CART:基于Gini指数(基尼指数)选择特征的算法
代码
#DT
import operator
from collections import Counter
from math import log

import numpy as np
def ShannonEnt(dataSet):
    """Compute the Shannon entropy of the class labels in *dataSet*.

    Each sample is a list whose last element is its class label.

    Args:
        dataSet: list of samples; the label of each sample is sample[-1].

    Returns:
        float: entropy in bits; 0.0 for an empty or single-class dataset.
    """
    num = len(dataSet)
    # Count occurrences of each class label (replaces the manual
    # dict-with-membership-test loop of the original).
    labelCount = Counter(data[-1] for data in dataSet)
    shannonEnt = 0.0
    for count in labelCount.values():
        # Cast before dividing so the result is correct even under
        # Python 2 integer division.
        prob = float(count) / num
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
def splitDataset(dataSet, axis, value):
    """Select the samples whose feature at position *axis* equals *value*.

    NOTE(review): unlike the book's version, the matched feature column is
    kept in the returned rows (the column is not stripped).

    Args:
        dataSet: list of samples (each a list of feature values + label).
        axis: feature index to filter on.
        value: feature value a sample must have to be included.

    Returns:
        list: the matching samples, in their original order.
    """
    return [sample for sample in dataSet if sample[axis] == value]
def creatDataSet():
    """Build the toy fish-identification dataset from the book.

    Returns:
        tuple: (dataSet, label) where every sample is
        [no surfacing, flippers, class] and *label* names the two
        feature columns.
    """
    samples = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    feature_names = ['no surfacing', 'flippers']
    return samples, feature_names
def choostBestFeature(dataSet):
    """Pick the feature index with the highest information gain (ID3).

    Args:
        dataSet: list of samples; the last element of each is the label.

    Returns:
        int: index of the best feature, or 0 when no feature has a
        strictly positive information gain.
    """
    num_feature = len(dataSet[0]) - 1
    base_entropy = ShannonEnt(dataSet)
    best_info_gain = 0.0
    # Bug fix: initialize best_feature so a dataset where no feature has
    # positive gain does not leave it unbound (NameError in the original).
    best_feature = 0
    for i in range(num_feature):
        feat_values = {example[i] for example in dataSet}
        feat_entropy = 0.0
        for value in feat_values:
            sub_dataSet = splitDataset(dataSet, i, value)
            # Bug fix: cast BEFORE dividing; the original
            # float(len(a) / len(b)) truncates to 0.0 under Python 2
            # integer division.
            prob = float(len(sub_dataSet)) / len(dataSet)
            feat_entropy += prob * ShannonEnt(sub_dataSet)
        info_gain = base_entropy - feat_entropy
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_feature = i
    return best_feature
def labelOfLeaf(class_list):
    """Return the majority class label (ties: first label encountered).

    Bug fix: the original read and wrote ``class_list`` (a list) where it
    meant the ``class_count`` dict — ``class_list[vote] = 0`` raises
    TypeError for string labels, and ``class_count`` stayed empty so the
    final indexing would fail.

    Args:
        class_list: non-empty list of class labels.

    Returns:
        The most frequent label in *class_list*.
    """
    class_count = {}
    for vote in class_list:
        if vote not in class_count:
            class_count[vote] = 0
        class_count[vote] += 1
    sorted_class_count = sorted(class_count.items(),
                                key=operator.itemgetter(1), reverse=True)
    return sorted_class_count[0][0]
def creatTree(dataSet, labels):
    """Recursively build an ID3 decision tree as nested dicts.

    Args:
        dataSet: non-empty list of samples; sample[-1] is the class label.
        labels: feature names aligned with the feature columns.

    Returns:
        A class label (leaf) or a nested dict of the form
        {feature_name: {feature_value: subtree}}.
    """
    class_list = [example[-1] for example in dataSet]
    # All samples share one class: this branch is a pure leaf.
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]
    # Only the label column is left: fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return labelOfLeaf(class_list)
    best_feat = choostBestFeature(dataSet)
    unique_feat_values = set(example[best_feat] for example in dataSet)
    # Bug fix: splitDataset keeps the feature column, so a best feature
    # with a single unique value would reproduce the identical dataset
    # and recurse forever; return a majority-vote leaf instead.
    if len(unique_feat_values) == 1:
        return labelOfLeaf(class_list)
    best_feat_label = labels[best_feat]
    myTree = {best_feat_label: {}}
    for value in unique_feat_values:
        sub_dataSet = splitDataset(dataSet, best_feat, value)
        myTree[best_feat_label][value] = creatTree(sub_dataSet, labels)
    return myTree
这个代码主要是机器学习实战上的,但书上有挺多错误,我都已经修改好了,在我电脑上是没有出现错误。有啥不对的,还请各位大神指教。这也算是手撸一次代码复习一下。
参考
机器学习实战 -Peter Harrington