# 决策树代码 (decision tree code)
from math import log
import operator
def calc_shannon_ent(data_set):
    """Return the Shannon entropy, in bits, of the class labels in *data_set*.

    The class label of each row is its last element. An empty data set
    yields 0.0.
    """
    total = len(data_set)
    tally = {}
    for row in data_set:
        tally[row[-1]] = tally.get(row[-1], 0) + 1
    entropy = 0.0
    for count in tally.values():
        p = float(count) / total
        # Each label contributes -p * log2(p).
        entropy -= p * log(p, 2)
    return entropy
def split_data_set(data_set, axis, value):
    """Return the rows whose column *axis* equals *value*, with that column removed.

    Each returned row is a new list built from slices of the input row, so
    the rows of *data_set* are left untouched.
    """
    def _without_axis(row):
        # Copy everything before and after the split column into a fresh list.
        trimmed = row[:axis]
        trimmed.extend(row[axis + 1:])
        return trimmed

    return [_without_axis(row) for row in data_set if row[axis] == value]
def choose_best_feature_to_split(data_set):
    """Return the index of the feature column with the largest information gain.

    The last column of each row is the class label, so only the first
    ``len(data_set[0]) - 1`` columns are candidates. Gain is the entropy of
    *data_set* minus the size-weighted entropy of the subsets produced by
    splitting on each distinct value of the column. Returns -1 when no
    column has strictly positive gain.
    """
    candidate_count = len(data_set[0]) - 1
    base_entropy = calc_shannon_ent(data_set)
    total = float(len(data_set))
    best_gain, best_index = 0.0, -1
    for idx in range(candidate_count):
        distinct_values = set([row[idx] for row in data_set])
        conditional = 0.0
        for val in distinct_values:
            subset = split_data_set(data_set, idx, val)
            conditional += len(subset) / total * calc_shannon_ent(subset)
        gain = base_entropy - conditional
        # Strict '>' keeps the earliest column on ties.
        if gain > best_gain:
            best_gain, best_index = gain, idx
    return best_index
def majority_cnt(class_list):
    """Return the most frequent label in *class_list*.

    On Python 3.7+ (insertion-ordered dicts), ties go to the label that
    appears first in *class_list*. Raises IndexError if *class_list* is empty.
    """
    class_count = {}
    for vote in class_list:
        class_count[vote] = class_count.get(vote, 0) + 1
    # dict.iteritems() only exists on Python 2; items() works on both.
    # sorted() is stable even with reverse=True, so equal counts keep
    # insertion order and the first-seen label wins a tie.
    sorted_class_count = sorted(class_count.items(),
                                key=operator.itemgetter(1), reverse=True)
    return sorted_class_count[0][0]
def create_tree(data_set, labels):
    """Recursively build an information-gain decision tree as nested dicts.

    Args:
        data_set: list of rows (lists). The last element of each row is the
            class label; the other elements are feature values.
        labels: feature names, one per feature column of *data_set*. The
            caller's sequence is NOT modified.

    Returns:
        A class label when the node is a leaf. Otherwise a dict
        ``{feature_name: {feature_value: subtree_or_label, ...}}``.

    Raises:
        IndexError: if *data_set* is empty.
    """
    class_list = [example[-1] for example in data_set]
    # Every row has the same class: this node is a leaf.
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]
    # Only the class column is left: no features remain, so take a vote.
    if len(data_set[0]) == 1:
        return majority_cnt(class_list)
    best_feat = choose_best_feature_to_split(data_set)
    if best_feat < 0:
        # choose_best_feature_to_split returns -1 when no feature has strictly
        # positive gain (e.g. XOR-like or contradictory rows). Using -1 as a
        # column index would split on the class column and mislabel the node,
        # so fall back to the first feature. Recursion still terminates
        # because each level removes one feature column.
        best_feat = 0
    best_feat_label = labels[best_feat]
    my_tree = {best_feat_label: {}}
    # Build the label list for the children without mutating the caller's list.
    sub_labels = labels[:best_feat] + labels[best_feat + 1:]
    unique_vals = set([example[best_feat] for example in data_set])
    for value in unique_vals:
        my_tree[best_feat_label][value] = create_tree(
            split_data_set(data_set, best_feat, value), sub_labels)
    return my_tree
# 绘制代码 (plotting code)
import matplotlib.pyplot as plt
import decison_tree
# Box styles passed as ``bbox`` to annotate (via plot_node) for internal
# decision nodes and for leaves, plus the ``arrowprops`` used for edges.
decision_node = {'boxstyle': 'sawtooth', 'fc': '0.8'}
leaf_node = {'boxstyle': 'round4', 'fc': '0.8'}
arrow_args = {'arrowstyle': '<-'}
def plot_node(node_text, center_pt, parent_pt, node_type):
    """Draw *node_text* in a box at *center_pt*, joined by an arrow to *parent_pt*.

    Both points are in axes-fraction coordinates. *node_type* is the box
    style dict (``bbox``), and the arrow style comes from ``arrow_args``.
    Draws on ``creat_plot.axl``, which creat_plot must have set beforehand.
    """
    # One annotate call draws the text box and the arrow together; a second,
    # arrow-less call would only paint the same box a second time.
    creat_plot.axl.annotate(node_text, xy=parent_pt, xycoords='axes fraction',
                            xytext=center_pt, textcoords='axes fraction',
                            va='center', ha='center', bbox=node_type,
                            arrowprops=arrow_args)
def creat_plot():
    """Demo: draw one sample decision node and one sample leaf node.

    The figure is saved to 'tree_plot.png' and then shown.

    NOTE(review): a later ``def creat_plot(in_tree)`` in this module rebinds
    the name, so this zero-argument version cannot be reached once the
    module has finished loading.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    creat_plot.axl = plt.subplot(111, frameon=False)
    samples = (
        ('Decision_node', (0.5, 0.1), (0.1, 0.5), decision_node),
        ('Leaf_node', (0.8, 0.1), (0.3, 0.8), leaf_node),
    )
    for text, center, parent, style in samples:
        plot_node(text, center, parent, style)
    plt.savefig('tree_plot.png')
    plt.show()
def get_num_leafs(my_tree):
    """Count the leaves (non-dict children) of a nested-dict decision tree.

    *my_tree* has the form ``{feature_name: {value: subtree_or_leaf, ...}}``;
    only its first top-level key is inspected. A child counts as a subtree
    only when its type is exactly ``dict``.
    """
    root = list(my_tree)[0]
    branches = my_tree[root]
    return sum(
        get_num_leafs(child) if type(child) is dict else 1
        for child in branches.values()
    )
def get_tree_depth(my_tree):
    """Return the number of decision levels in a nested-dict decision tree.

    A root whose children are all leaves has depth 1, and each nested
    subtree adds one level. Only the first top-level key of *my_tree* is
    inspected, and a root with no children has depth 0.
    """
    root = list(my_tree)[0]
    branches = my_tree[root]
    return max(
        (1 + get_tree_depth(child) if type(child) is dict else 1
         for child in branches.values()),
        default=0,
    )
def retrieve_tree(i):
    """Return sample tree number *i* from two hard-coded nested-dict trees.

    The trees are rebuilt on every call, so callers may modify the result
    freely. Raises IndexError for an index outside the list.
    """
    list_of_trees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        # The 'flippers' == 1 leaf is the class label 'no'; it was
        # mistyped with a digit zero ('n0').
        {'no surfacing': {0: 'no',
                          1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}},
                                           1: 'no'}}}},
    ]
    return list_of_trees[i]
def plot_mid_text(cntr_pt, parent_pt, txt_string):
    """Write *txt_string* at the midpoint between *cntr_pt* and *parent_pt*.

    Used to label the edge between a node and its parent. Draws on
    ``creat_plot.axl``.
    """
    half_dx = (parent_pt[0] - cntr_pt[0]) / 2.0
    half_dy = (parent_pt[1] - cntr_pt[1]) / 2.0
    creat_plot.axl.text(half_dx + cntr_pt[0], half_dy + cntr_pt[1], txt_string)
def plot_tree(my_tree, parent_pt, node_txt):
    """Recursively draw the (sub)tree *my_tree* below *parent_pt*.

    *node_txt* labels the edge from the parent to this node. Layout state is
    kept in function attributes that creat_plot initialises:
    ``plot_tree.totalW`` (leaf count of the whole tree), ``plot_tree.totalD``
    (its depth) and the cursors ``plot_tree.xOff`` / ``plot_tree.yOff``.
    This function advances those cursors as it draws.
    """
    num_leafs = get_num_leafs(my_tree)
    first_str = list(my_tree)[0]
    # Centre this decision node horizontally over the leaf slots it spans.
    cntr_pt = (plot_tree.xOff + (1.0 + float(num_leafs)) / 2.0 / plot_tree.totalW,
               plot_tree.yOff)
    plot_mid_text(cntr_pt, parent_pt, node_txt)
    plot_node(first_str, cntr_pt, parent_pt, decision_node)
    second_dict = my_tree[first_str]
    # Move down one level for the children.
    plot_tree.yOff = plot_tree.yOff - 1.0 / plot_tree.totalD
    for key, child in second_dict.items():
        # Same "exact dict" subtree test that get_num_leafs uses, so the
        # layout agrees with the leaf count.
        if type(child) is dict:
            plot_tree(child, cntr_pt, str(key))
        else:
            # Each leaf takes the next horizontal slot.
            plot_tree.xOff = plot_tree.xOff + 1.0 / plot_tree.totalW
            plot_node(child, (plot_tree.xOff, plot_tree.yOff), cntr_pt, leaf_node)
            plot_mid_text((plot_tree.xOff, plot_tree.yOff), cntr_pt, str(key))
    # Restore the vertical cursor for this node's siblings.
    plot_tree.yOff = plot_tree.yOff + 1.0 / plot_tree.totalD
def creat_plot(in_tree):
    """Render the nested-dict decision tree *in_tree* with matplotlib.

    Sets ``creat_plot.axl`` (a frameless axes with no ticks) and the layout
    attributes that plot_tree reads, then draws the tree. The figure is
    saved to 'tree_plotter.png' and shown.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    creat_plot.axl = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    width = float(get_num_leafs(in_tree))
    plot_tree.totalW = width
    plot_tree.totalD = float(get_tree_depth(in_tree))
    # Start half a slot left of 0; plot_tree advances xOff by one slot before
    # drawing each leaf, so the first leaf lands in the middle of slot 0.
    plot_tree.xOff = -0.5 / width
    plot_tree.yOff = 1.0
    plot_tree(in_tree, (0.5, 1.0), '')
    plt.savefig('tree_plotter.png')
    plt.show()
def classify(input_tree, feat_labels, test_vec):
    """Return the class label that *input_tree* predicts for *test_vec*.

    Args:
        input_tree: nested-dict tree ``{feature_name: {value: subtree_or_label}}``,
            such as the trees returned by create_tree.
        feat_labels: feature names, in the same order as the values in
            *test_vec*. Used to find which position of *test_vec* each node
            tests.
        test_vec: feature values of the sample to classify.

    Raises:
        ValueError: if a node's feature name is missing from *feat_labels*
            (raised by ``index``), or if *test_vec* holds a value for which
            the tree has no branch.
    """
    first_str = list(input_tree)[0]
    second_dict = input_tree[first_str]
    feat_index = feat_labels.index(first_str)
    value = test_vec[feat_index]
    for key, subtree in second_dict.items():
        if value == key:
            if isinstance(subtree, dict):
                return classify(subtree, feat_labels, test_vec)
            return subtree
    # Without this, an unseen value left the result unbound (UnboundLocalError).
    raise ValueError('no branch for value %r of feature %r' % (value, first_str))
def store_tree(input_tree, filename):
    """Serialise *input_tree* to *filename* with pickle, overwriting the file.

    Read it back with grab_tree.
    """
    import pickle
    # pickle writes bytes, so the file must be opened in binary mode (text
    # mode 'w' raises TypeError on Python 3). `with` closes the file even
    # if dump fails.
    with open(filename, 'wb') as fw:
        pickle.dump(input_tree, fw)
def grab_tree(filename):
    """Load and return a tree previously saved by store_tree.

    Security: ``pickle.load`` can execute arbitrary code while loading.
    Only call this on files you trust.
    """
    import pickle
    # Binary mode is required for pickle on Python 3, and `with` ensures the
    # handle is closed after loading.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
if __name__ == '__main__':
    # Each line of lenses.txt is read as one tab-separated row. create_tree
    # treats the last field as the class label, and lenses_labels names the
    # other fields. Blank lines are skipped so they do not become [''] rows,
    # and `with` closes the file once it has been read.
    with open('lenses.txt') as fr:
        lenses = [inst.strip().split('\t') for inst in fr if inst.strip()]
    lenses_labels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lenses_tree = decison_tree.create_tree(lenses, lenses_labels)
    creat_plot(lenses_tree)