1.前言
分类与回归树(classification and regression tree,CART) 模型由Breiman等人在1984年提出,即可应用分类也可用于回归的一个树模型。CART是在给定输入随机变量
条件下输出随机变量
的条件概率分布的学习方法。
本文简单介绍了回归树的算法描述,辅以简单的例子以加深理解。
公式编辑技巧:行内公式:$公式$,块公式:$$公式$$,加粗:**符号**
2.回归树
决策树实际上就是用超平面对空间进行划分的一种方式,每次划分时,都是将结点的数据集一分为二,根据相应的决策方法,一步一步的进行延伸,即基于某种决策递归的构建二叉树的过程。
2.1.原理
假设
与
分别为输入和输出变量,给定训练数据集:
一个回归树的生成对应着对输入空间
的划分,以及在划分的单元上的输出值。假设已将输入空间划分为
个单元
,并且在每个单元上都有一个固定的输出值
,于是回归树模型可以表示为:
当输入的数据集划分确定后,我们可以使用平方误差
来表示回归树对于训练样本的预测误差,因此可以使用平方误差最小化的原则来求解每个单元上的最优输出值,已知单元上
是对应的
上的所有输入数据
所对应的输出值
的均值,即:
然而在分类树中,我们常常采用信息熵等方法对输入空间进行分类,然而在回归树中我们采用启发式(依靠经验)的方法。随机性的选择一个
对应的变量
和他的取值s,作为划分的一个切分点,即将两个输入空间切分成两个区域:
,和
,然后遍历所有特征,并获取其对应的值,找到最优的特征
和对应的
,从而使得损失函数最小,即求解:
即可找到最优输入变量
及其所对应的
.
2.2算法步骤(最小二乘回归树)
step1: 在输入的数据集中启发性的选择一个变量 及其对应的值,进一步将输入空间划分为两个区域;
step2: 用选定的特征
以及其对应的
划分区域并计算其对应的输出值:
step3: 分别计算划分的两个区域的平方误差;
step4: 递归的遍历所有数据的特征,找到最优的
和
,求解:
step5: 对划分后的子区域
和
,再重复step1,step2,step3,直到满足所设定的条件为止;
step6: 将输入的区间划分为
个区域,生成决策树;
啥也不说了,直接上例子了
下表是烟台市近几年的年平均降水量,其中(1,10)分别代表(2007,2018)
分析如下图所示:
将2019年带入预测的模型分别在深度为1和10的情况下得到了675.5mm和690.0mm,可见深度对回归树的影响甚大。
下图为加入线性回归预测后:
预测2019年大约降水量为291.6mm.
代码如下:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from sklearn.tree import DecisionTreeRegressor as dtr
from sklearn import linear_model
matplotlib.rcParams['font.sans-serif']=['SimHei']
x = np.array(list(range(1,11))).reshape(-1,1)
x_1 = (2008,2009,2010,2011,2012,2013,2014,2015,2016,2017)
y = [620,628,640,630,637,641,661,667,684,690]
model1 = dtr(max_depth = 1)
model2 = dtr(max_depth = 10)
model3 = linear_model.LinearRegression()
model1.fit(x,y)
model2.fit(x,y)
model3.fit(x,y)
X_test = np.arange(0.0,10.0,0.01)[:,np.newaxis]
y_1 = model1.predict(X_test)
y_2 = model2.predict(X_test)
y_3 = model3.predict(X_test)
y_4 = model3.predict(11)
print(y_4)
plt.figure()
plt.scatter(x,y,s=20,edgecolor="black",c="darkorange",label="数据")
plt.plot(X_test,y_1,color="cornflowerblue",label="max_depth=1",linewidth=2)
plt.plot(X_test,y_2,color="yellowgreen",label="max_depth=10",linewidth=2)
plt.plot(X_test,y_3,color="red",label="regression",linewidth=2)
plt.xlabel("数据")
plt.ylabel("降雨量")
plt.title("回归树")
plt.legend()
plt.show()
本人小白一枚,文章欠妥之处还望指正!!
参考
本文参考了博主一个拉风的名字的文章。