Scikit-Learn 机器学习笔记 – 决策树
import numpy as np
def load_dataset():
    """Load the iris dataset restricted to its two petal features.

    Returns:
        tuple: ``(X, y, iris)`` where ``X`` is columns 2 and 3 of
        ``iris['data']`` (petal length and petal width in the standard iris
        feature order), ``y`` is ``iris['target']``, and ``iris`` is the full
        Bunch returned by ``datasets.load_iris()`` (kept so callers can read
        ``feature_names`` / ``target_names``).
    """
    from sklearn import datasets

    dataset = datasets.load_iris()
    petal_features = dataset['data'][:, (2, 3)]
    labels = dataset['target']
    return petal_features, labels, dataset
def tree_classify(X, y, max_depth=2, random_state=None):
    """Fit a depth-limited decision tree classifier and print its split rules.

    Args:
        X: Training feature matrix (in this script, the two petal columns
            returned by ``load_dataset``).
        y: Training class labels.
        max_depth: Maximum depth of the tree. Defaults to 2, the value that
            was previously hard-coded.
        random_state: Seed forwarded to ``DecisionTreeClassifier``. The
            estimator permutes features at every split, so when several
            splits are equally good the chosen one can differ between runs
            unless this is fixed. ``None`` keeps the previous behaviour.

    Returns:
        The fitted ``DecisionTreeClassifier``.
    """
    from sklearn.tree import DecisionTreeClassifier, export_text

    tree_clf = DecisionTreeClassifier(max_depth=max_depth, random_state=random_state)
    tree_clf.fit(X, y)
    # The low-level ``tree_`` object has no informative repr (printing it
    # only shows an object address), so print the learned rules instead.
    print(export_text(tree_clf))
    return tree_clf
def draw_tree(model, iris, out_file="iris_tree.dot"):
    """Export a fitted decision tree as Graphviz DOT source.

    Args:
        model: A fitted decision tree classifier.
        iris: The iris Bunch returned by ``load_dataset``. Its
            ``feature_names`` and ``target_names`` label the graph nodes.
        out_file: Destination path for the DOT file. Defaults to
            ``"iris_tree.dot"``, the previously hard-coded name.

    Returns:
        The ``out_file`` path that was written, e.g. to pass to Graphviz
        (``dot -Tpng iris_tree.dot -o iris_tree.png``).
    """
    from sklearn.tree import export_graphviz

    export_graphviz(
        model,
        out_file=out_file,
        # load_dataset trains on columns 2 and 3 only, so label the nodes
        # with the matching slice of the feature names.
        feature_names=iris.feature_names[2:],
        class_names=iris.target_names,
        rounded=True,
        filled=True,
    )
    return out_file
def tree_predict(model, sample):
    """Predict class labels and class probabilities for ``sample``.

    The results are printed and also returned, so callers can use them
    instead of only reading stdout.

    Args:
        model: A fitted classifier exposing ``predict`` and ``predict_proba``
            (e.g. the tree returned by ``tree_classify``).
        sample: 2-D array-like, one row per sample, such as ``[[5, 1.5]]``
            for a single flower.

    Returns:
        tuple: ``(labels, probabilities)``, exactly as returned by
        ``model.predict`` and ``model.predict_proba``.
    """
    labels = model.predict(sample)
    probabilities = model.predict_proba(sample)
    print('决策树预测类别为:', labels, '属于各类别的概率为:', probabilities)
    return labels, probabilities
if __name__ == '__main__':
    # Train a tree on the iris petal features, then classify one example.
    X, y, iris = load_dataset()
    tree_clf = tree_classify(X, y)
    # A batch of one row: petal length 5, petal width 1.5.
    new_flower = [[5, 1.5]]
    tree_predict(tree_clf, new_flower)