# 版权声明：以学习交流为主，未经博主同意禁止转载，禁止用于商用。原文：https://blog.csdn.net/u012965373/article/details/83929679
# 代码模块一、DecisionTreePlot（decision_tree_plot.py）
# -*- coding:utf-8 -*-
__author__ = 'yangxin_ryan'
import matplotlib.pyplot as plt
# Box and arrow styles passed to matplotlib's annotate() when drawing the tree.
# boxstyle: "sawtooth" draws a jagged-edged box (used for decision nodes),
# "round4" a rounded box (used for leaf nodes).
# fc is the box *face* (fill) colour, not the font colour. A numeric string
# such as "0.8" is a grey level from "0" (black) to "1" (white), so larger
# values give a lighter fill.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
# "<-" places the arrowhead at the annotation-text (xytext) end of the arrow,
# i.e. on the child box when the parent point is passed as xy.
arrow_args = dict(arrowstyle="<-")
class DecisionTreePlot(object):
    """Draw a nested-dict decision tree with matplotlib annotations.

    A tree is a dict with a single key (the feature tested at that node)
    whose value maps each feature value either to a sub-tree (another such
    dict) or to a leaf class label, e.g.
    ``{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}``.

    The per-drawing layout state (axes ``ax1``, ``total_w``, ``total_d``
    and the running offsets ``x_off`` / ``y_off``) lives on the instance
    and is (re)initialised by ``create_plot``.
    """

    @staticmethod
    def _root_key(tree):
        """Return the single top-level key (the feature name) of *tree*."""
        # dict.keys() is not indexable in Python 3, so take the first key
        # from an iterator instead of ``tree.keys()[0]``.
        return next(iter(tree))

    def get_num_leafs(self, my_tree):
        """Return the number of leaves (class labels) in *my_tree*."""
        branches = my_tree[self._root_key(my_tree)]
        num_leafs = 0
        for child in branches.values():
            if isinstance(child, dict):
                num_leafs += self.get_num_leafs(child)
            else:
                num_leafs += 1
        return num_leafs

    def get_tree_depth(self, my_tree):
        """Return the number of decision levels in *my_tree*.

        A node whose children are all leaves has depth 1; every nested
        sub-tree adds one level.
        """
        branches = my_tree[self._root_key(my_tree)]
        max_depth = 0
        for child in branches.values():
            if isinstance(child, dict):
                this_depth = 1 + self.get_tree_depth(child)
            else:
                this_depth = 1
            max_depth = max(max_depth, this_depth)
        return max_depth

    def plot_node(self, node_txt, center_pt, parent_pt, node_type):
        """Draw *node_txt* in a box at *center_pt* with an arrow from *parent_pt*.

        Both points are in axes-fraction coordinates. *node_type* is the
        bbox style dict (``decisionNode`` or ``leafNode``). Requires
        ``create_plot`` to have set ``self.ax1``.
        """
        self.ax1.annotate(node_txt, xy=parent_pt, xycoords='axes fraction',
                          xytext=center_pt, textcoords='axes fraction',
                          va="center", ha="center", bbox=node_type,
                          arrowprops=arrow_args)

    def plot_mid_text(self, cntr_pt, parent_pt, txt_string):
        """Write the edge label *txt_string* halfway between child and parent."""
        x_mid = (parent_pt[0] - cntr_pt[0]) / 2.0 + cntr_pt[0]
        y_mid = (parent_pt[1] - cntr_pt[1]) / 2.0 + cntr_pt[1]
        self.ax1.text(x_mid, y_mid, txt_string, va="center", ha="center", rotation=30)

    def plot_tree(self, my_tree, parent_pt, node_txt):
        """Recursively draw *my_tree* hanging from *parent_pt*.

        *node_txt* labels the edge coming from the parent. The method reads
        and advances the layout state set up by ``create_plot``: each leaf
        takes the next ``1 / total_w`` horizontal slot, and each level moves
        ``1 / total_d`` down (restored on the way back up).
        """
        num_leafs = self.get_num_leafs(my_tree)
        first_str = self._root_key(my_tree)
        # Centre this decision node above the horizontal span of its leaves.
        cntr_pt = (self.x_off + (1.0 + float(num_leafs)) / 2.0 / self.total_w, self.y_off)
        self.plot_mid_text(cntr_pt, parent_pt, node_txt)
        self.plot_node(first_str, cntr_pt, parent_pt, decisionNode)
        second_dict = my_tree[first_str]
        self.y_off -= 1.0 / self.total_d
        for key, child in second_dict.items():
            if isinstance(child, dict):
                self.plot_tree(child, cntr_pt, str(key))
            else:
                self.x_off += 1.0 / self.total_w
                leaf_pt = (self.x_off, self.y_off)
                # Leaves hang off this node's own centre and use the leaf style.
                self.plot_node(child, leaf_pt, cntr_pt, leafNode)
                self.plot_mid_text(leaf_pt, cntr_pt, str(key))
        self.y_off += 1.0 / self.total_d

    def create_plot(self, in_tree):
        """Open figure 1, lay out and draw *in_tree*, then show it.

        Resets all layout state first, so one instance can draw several trees.
        """
        fig = plt.figure(1, facecolor='green')
        fig.clf()
        axprops = dict(xticks=[], yticks=[])
        self.ax1 = plt.subplot(111, frameon=False, **axprops)
        self.total_w = float(self.get_num_leafs(in_tree))
        self.total_d = float(self.get_tree_depth(in_tree))
        # Start half a slot left of 0 so the first leaf lands at 0.5 / total_w.
        self.x_off = -0.5 / self.total_w
        self.y_off = 1.0
        self.plot_tree(in_tree, (0.5, 1.0), '')
        plt.show()

    def retrieve_tree(self, i):
        """Return hard-coded sample tree number *i* (0 or 1) for trying the plotter."""
        list_of_trees = [
            {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
            {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
        ]
        return list_of_trees[i]
# 代码模块二、DescionTreeApp
# -*- coding:utf-8 -*-
__author__ = 'yangxin_ryan'
import operator
from math import log
from src.descion_tree.decision_tree_plot import DecisionTreePlot as dtPlot
import pickle
import copy
class DescionTreeApp(object):
    """ID3-style decision-tree construction, classification and persistence.

    A data set is a list of rows. Every column except the last is a feature
    value, and the last column is the class label. Trees are nested dicts
    ``{feature_name: {feature_value: subtree_or_label}}``.
    """

    def create_data_set(self):
        """Return a five-row toy data set and its two feature names.

        Returns:
            (data_set, labels) where each row is
            ``[no surfacing, flippers, class]``.
        """
        data_set = [[1, 1, 'yes'],
                    [1, 1, 'yes'],
                    [1, 0, 'no'],
                    [0, 1, 'no'],
                    [0, 1, 'no']]
        labels = ['no surfacing', 'flippers']
        return data_set, labels

    def calc_shannon_ent(self, data_set):
        """Return the Shannon entropy, in bits, of the class column of *data_set*.

        An empty data set yields 0.0.
        """
        num_entries = len(data_set)
        label_counts = {}
        for feat_vec in data_set:
            current_label = feat_vec[-1]
            label_counts[current_label] = label_counts.get(current_label, 0) + 1
        shannon_ent = 0.0
        for count in label_counts.values():
            prob = float(count) / num_entries
            shannon_ent -= prob * log(prob, 2)
        return shannon_ent

    def split_data_set(self, data_set, index, value):
        """Return the rows whose column *index* equals *value*, minus that column.

        Input rows are not modified; each returned row is a new sequence.
        """
        return [feat_vec[:index] + feat_vec[index + 1:]
                for feat_vec in data_set if feat_vec[index] == value]

    def choose_best_feature_to_split(self, data_set):
        """Return the index of the feature with the largest information gain.

        Returns -1 when no feature gives a strictly positive gain (for example
        when every remaining feature is constant); callers must handle that.
        """
        num_features = len(data_set[0]) - 1
        base_entropy = self.calc_shannon_ent(data_set)
        total = float(len(data_set))
        best_info_gain, best_feature = 0.0, -1
        for i in range(num_features):
            unique_vals = set(example[i] for example in data_set)
            new_entropy = 0.0
            for value in unique_vals:
                sub_data_set = self.split_data_set(data_set, i, value)
                prob = len(sub_data_set) / total
                new_entropy += prob * self.calc_shannon_ent(sub_data_set)
            info_gain = base_entropy - new_entropy
            if info_gain > best_info_gain:
                best_info_gain = info_gain
                best_feature = i
        return best_feature

    def majority_cnt(self, class_list):
        """Return the most frequent label in *class_list*.

        On a tie the label seen first wins, because sorted() is stable even
        with reverse=True.
        """
        class_count = {}
        for vote in class_list:
            class_count[vote] = class_count.get(vote, 0) + 1
        sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
        return sorted_class_count[0][0]

    def create_tree(self, data_set, labels):
        """Build an ID3 decision tree from *data_set*.

        Args:
            data_set: rows of feature values followed by a class label.
            labels: one feature name per feature column. The caller's list
                is not modified.

        Returns:
            A class label for a leaf; otherwise a nested dict
            ``{feature_name: {feature_value: subtree_or_label}}``.
        """
        class_list = [example[-1] for example in data_set]
        # Every row has the same class: pure leaf.
        if class_list.count(class_list[0]) == len(class_list):
            return class_list[0]
        # Only the class column is left: fall back to a majority vote.
        if len(data_set[0]) == 1:
            return self.majority_cnt(class_list)
        best_feat = self.choose_best_feature_to_split(data_set)
        # No feature reduces entropy. Index -1 would mean "split on the class
        # column itself", so vote instead.
        if best_feat < 0:
            return self.majority_cnt(class_list)
        best_feat_label = labels[best_feat]
        # Build a new list rather than deleting from the caller's labels.
        sub_labels = labels[:best_feat] + labels[best_feat + 1:]
        my_tree = {best_feat_label: {}}
        for value in set(example[best_feat] for example in data_set):
            my_tree[best_feat_label][value] = self.create_tree(
                self.split_data_set(data_set, best_feat, value), sub_labels)
        return my_tree

    def classify(self, input_tree, feat_labels, test_vec):
        """Return the class label that *input_tree* predicts for *test_vec*.

        *feat_labels* names the feature at each position of *test_vec*.
        A bare label (a tree that is just a leaf) is returned unchanged.

        Raises:
            KeyError: *test_vec* has a feature value the tree has no branch for.
            ValueError: a feature tested by the tree is not in *feat_labels*.
        """
        node = input_tree
        while isinstance(node, dict):
            first_str = next(iter(node))
            feat_index = feat_labels.index(first_str)
            node = node[first_str][test_vec[feat_index]]
        return node

    def store_tree(self, input_tree, filename):
        """Pickle *input_tree* to *filename*, overwriting any existing file."""
        with open(filename, 'wb') as fw:
            pickle.dump(input_tree, fw)

    def grab_tree(self, filename):
        """Load and return a tree previously saved by store_tree.

        Security: pickle.load can execute arbitrary code, so only load files
        from a trusted source.
        """
        with open(filename, 'rb') as fr:
            return pickle.load(fr)

    # Application test 2: predict the type of contact lenses.
    def app_contact_lenses(self, filename='lenses.txt'):
        """Build a tree from a tab-separated lenses file and plot it.

        Each non-blank line holds four feature values (age, prescript,
        astigmatic, tearRate) followed by the class label, separated by tabs.

        Args:
            filename: path of the data file. NOTE(review): the default name
                is an assumption (the previous code opened '', which always
                fails); confirm where the dataset lives.
        """
        with open(filename) as fr:
            lenses = [inst.strip().split('\t') for inst in fr if inst.strip()]
        lenses_labels = ['age', 'prescript', 'astigmatic', 'tearRate']
        lenses_tree = self.create_tree(lenses, lenses_labels)
        # create_plot is an instance method, so call it on an instance.
        dtPlot().create_plot(lenses_tree)
if __name__ == "__main__":
    # Script entry point: run the contact-lens demo end to end.
    DescionTreeApp().app_contact_lenses()