# Copyright notice: originally published on the author's blog; reproduction
# without the author's permission is prohibited.
# https://blog.csdn.net/qq_27668313/article/details/79063222
# ID3 decision tree, Python 3. Code first, theory to be added later; no pruning step.
import math
import operator
from collections import Counter
# Compute the Shannon entropy of a data set's class labels.
def calcEntropy(data):
    """Return the entropy (in bits) of the class labels of *data*.

    data: list of samples; each sample is a sequence whose LAST element
        is the class label (the other elements are attribute values).
    Returns 0.0 for an empty data set.
    """
    total = len(data)
    # Count occurrences of each class label (last column of every sample).
    counts = Counter(sample[-1] for sample in data)
    entropy = 0.0
    for count in counts.values():
        prob = count / total          # probability of this class
        entropy -= prob * math.log(prob, 2)  # accumulate -p*log2(p)
    return entropy
# Select the samples whose attribute i equals setValue, dropping that attribute.
def splitData(data, i, setValue):
    """Return a new list of samples from *data* where attribute *i* equals
    *setValue*; attribute *i* itself is removed from each returned sample.
    """
    return [sample[:i] + sample[i + 1:] for sample in data if sample[i] == setValue]
def selAttribute(data):
    """Return the index of the attribute with the highest information gain (ID3).

    data: list of samples, each a list of attribute values with the class
        label in the last position.
    Always returns a valid attribute index (0 .. numAttributes-1).

    Fixes over the original: the inner counting loop reused the name ``i``,
    clobbering the attribute index before splitData was called; and the
    function returned on the first attribute whose gain beat a running
    threshold (returning None when none did) instead of selecting the
    maximum-gain attribute.
    """
    totalEntropy = calcEntropy(data)  # entropy of the whole data set
    numAttributes = len(data[0]) - 1  # last column is the class label
    bestGain = -1.0
    bestIndex = 0
    for i in range(numAttributes):
        valueList = [sample[i] for sample in data]  # attribute i across samples
        counts = {}
        for v in valueList:
            counts[v] = counts.get(v, 0) + 1
        # Conditional entropy of the class given attribute i.
        condEntropy = 0.0
        for value in set(valueList):
            subData = splitData(data, i, value)
            prob = counts[value] / len(valueList)
            condEntropy += prob * calcEntropy(subData)
        gain = totalEntropy - condEntropy  # information gain of attribute i
        if gain > bestGain:
            bestGain = gain
            bestIndex = i
    return bestIndex
# When no attributes remain, label the leaf with the majority class.
def majorVote(classList):
    """Return the most frequent label in *classList* (ties broken by first
    occurrence, since dicts preserve insertion order and sorted() is stable).

    Fix: the original called ``sorted(classCount.items, ...)`` — missing the
    call parentheses — which raised TypeError whenever this ran.
    """
    classCount = {}
    for label in classList:
        classCount[label] = classCount.get(label, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def createTree(data, attribute):
    """Recursively build an ID3 decision tree as nested dicts.

    data: list of samples (attribute values + class label in last position).
    attribute: attribute names aligned with the sample columns; the chosen
        attribute's name is deleted from this list in place.
    Returns either a class label (leaf) or ``{attrName: {value: subtree}}``.
    """
    labels = [sample[-1] for sample in data]
    # All samples share one class: this branch becomes a leaf.
    if labels.count(labels[0]) == len(labels):
        return labels[0]
    # Only the class column is left: majority-vote leaf.
    if len(data[0]) == 1:
        return majorVote(labels)
    bestIndex = selAttribute(data)       # best splitting attribute's index
    bestName = attribute[bestIndex]
    tree = {bestName: {}}
    del attribute[bestIndex]             # this attribute is consumed here
    branchValues = set(sample[bestIndex] for sample in data)
    for value in branchValues:
        childAttrs = attribute[:]        # copy so recursion can't disturb siblings
        childData = splitData(data, bestIndex, value)
        tree[bestName][value] = createTree(childData, childAttrs)
    return tree
def createDataSet():
    """Return a toy data set and its attribute names.

    Each sample is [no surfacing, flippers, class label].
    """
    samples = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    names = ['no surfacing', 'flippers']
    # Alternative data set (attributes: root stem / texture):
    # samples = [[1, 0, 'good'], [1, 0, 'good'], [0, 0, 'bad'], [0, 1, 'bad'], [1, 1, 'bad']]
    # names = ['根蒂', '纹理']
    return samples, names
if __name__ == '__main__':
    # Build the tree from the toy data set and show the nested-dict result.
    samples, names = createDataSet()
    decisionTree = createTree(samples, names)
    print(decisionTree)