-
import math
-
import operator
-
-
-
def calcShannonEnt(dataset):
    """Return the Shannon entropy, in bits, of the class labels in ``dataset``.

    The class label of each row is its last element. An empty dataset has
    no labels and yields 0.0.
    """
    total = len(dataset)

    # Tally how many rows carry each class label.
    tally = {}
    for row in dataset:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1

    # H = -sum(p * log2(p)), accumulated one term at a time.
    entropy = 0.0
    for count in tally.values():
        p = float(count) / total
        entropy -= p * math.log(p, 2)
    return entropy
-
-
-
def CreateDataSet(filename='Dataset.txt', labelLineIndex=2, dataStart=4, dataEnd=11):
    """Load the training rows and feature names from a whitespace-separated text file.

    With the defaults, line 3 of ``Dataset.txt`` (index 2) holds the feature
    names and lines 5-11 (indices 4-10) hold the training rows. Each row is
    split on whitespace; its last token is used as the class label by the
    tree-building code.

    Args:
        filename: path of the data file.
        labelLineIndex: 0-based index of the line holding the feature names.
        dataStart: 0-based index of the first training row.
        dataEnd: 0-based index one past the last training row.

    Returns:
        ``(dataSet, labels)``: ``dataSet`` is a list of token lists, one per
        training row; ``labels`` is the list of feature names.

    Raises:
        IOError/OSError: if the file cannot be opened.
        IndexError: if the file has no line at ``labelLineIndex``.
    """
    # Read everything up front and close the file deterministically.
    with open(filename) as f:
        lines = f.readlines()

    # str.split() with no argument already ignores surrounding whitespace.
    labels = lines[labelLineIndex].split()
    dataSet = [line.split() for line in lines[dataStart:dataEnd]]
    return dataSet, labels
-
-
def splitDataSet(dataSet, axis, value):
    """Return the rows of ``dataSet`` whose column ``axis`` equals ``value``.

    Each returned row is a new list with column ``axis`` removed, so the
    rows of ``dataSet`` itself are left untouched.
    """
    return [
        row[:axis] + row[axis + 1:]
        for row in dataSet
        if row[axis] == value
    ]
-
-
-
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature column with the highest information gain.

    The last column of each row is the class label and is never considered.
    Returns -1 when no feature yields a strictly positive gain (including
    when there are no feature columns at all).
    """
    featureCount = len(dataSet[0]) - 1
    rowTotal = float(len(dataSet))
    baseEntropy = calcShannonEnt(dataSet)

    bestIndex, bestGain = -1, 0.0
    for col in range(featureCount):
        # Weighted entropy of the partitions induced by this column's values.
        conditional = 0.0
        for v in {row[col] for row in dataSet}:
            part = splitDataSet(dataSet, col, v)
            conditional += (len(part) / rowTotal) * calcShannonEnt(part)

        gain = baseEntropy - conditional
        # Strict '>' keeps the earliest column on ties.
        if gain > bestGain:
            bestIndex, bestGain = col, gain
    return bestIndex
-
-
-
def majorityCnt(classList):
    """Return the most frequent class label in ``classList``.

    On ties, ``sorted`` is stable (also with ``reverse=True``), so the label
    counted first wins. On Python 3.7+ that is the label appearing first in
    ``classList``, because dicts keep insertion order.

    Raises:
        IndexError: if ``classList`` is empty.
    """
    classCount = {}
    for vote in classList:
        # Increment the running count (the original reset it to 1 each time).
        classCount[vote] = classCount.get(vote, 0) + 1
    # items() exists on both Python 2 and 3; iteritems() is Python 2 only.
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
-
-
-
def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataSet: list of rows; every element except the last is a feature
            value and the last element is the class label.
        labels: feature names, one per feature column of ``dataSet``. The
            list is NOT modified, so callers can reuse it afterwards.

    Returns:
        A class label (leaf) when the rows cannot or need not be split any
        further; otherwise a nested dict
        ``{featureName: {featureValue: subtree, ...}}``.
    """
    classList = [row[-1] for row in dataSet]

    # Every row has the same class: pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # No feature columns left, only the class column: majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataSet)
    # -1 means no feature has positive information gain. Indexing labels and
    # splitting with -1 would split on the class column itself and corrupt
    # the tree, so fall back to a majority-vote leaf instead.
    if bestFeat < 0:
        return majorityCnt(classList)

    bestFeatLabel = labels[bestFeat]
    # Build a reduced copy rather than `del labels[bestFeat]`, which used to
    # mutate the caller's list.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]

    myTree = {bestFeatLabel: {}}
    for value in set(row[bestFeat] for row in dataSet):
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
-
-
-
if __name__ == '__main__':
    # Build the decision tree from Dataset.txt and show it; guarded so that
    # importing this module does not read the file.
    myDat, labels = CreateDataSet()
    myTree = createTree(myDat, labels)
    # A single parenthesized argument prints identically on Python 2 and 3.
    print(myTree)
运行结果如下:
{'outlook': {'overcast': 'Y', 'sunny': 'N', 'rain': {'windy': {'false': 'Y', 'true': 'N'}}}}
训练集和测试集
-
训练集:

outlook temperature humidity windy
---------------------------------------------------------
sunny hot high false N
sunny hot high true N
overcast hot high false Y
rain mild high false Y
rain cool normal false Y
rain cool normal true N
overcast cool normal true Y
-
-
测试集:
outlook temperature humidity windy
---------------------------------------------------------
sunny mild high false
sunny cool normal false
rain mild normal false
sunny mild normal true
overcast mild high true
overcast hot normal false
rain mild high true
-