这篇博客用来介绍梯度提升(Gradient boosting),在前向分步算法中,如果我们使用平方损失或者指数损失,那么其每一步的优化过程是很简单的,但是如果采用其他的损失函数呢,往往优化并不容易。因此这里给出一种更加通用的方式去求得最终的模型,即计算损失函数在当前模型上的梯度值,采取梯度下降的方式,使得当前模型逐渐逼近最优模型。
在这里我们基模型还采用决策树(实际上为CART树),并且我们要解决回归问题,采用的损失函数为平方误差损失,其中为当前模型(由前个基模型加和而成)。则损失函数关于在处的负梯度为,因此我们只需让沿着这个方向变化即可,即。因此,我们使用一个回归树来对负梯度值进行拟合,也就是以作为训练样本训练得到一个CART树,则,这样就相当于做了一次梯度下降,依次循环迭代就可以得到最终,而我们往往指定为常数使得最小。
其实可以看出,当损失函数采用平方误差损失时,负梯度所的到的值实际上还是当前模型预测值域真实值的残差,实际上回归树拟合的还是残差。
综上我们给出下图:
下面是sklearn库中GBDT的使用:
# coding: utf-8
# In[8]:
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split, GridSearchCV
import numpy as np
from numpy import random as rd
np.set_printoptions(linewidth=1000, suppress=True)
# In[9]:
def get_and_process_data():
x = load_boston()["data"]
y = load_boston()["target"]
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=2)
stand_transfer = StandardScaler()
x_train = stand_transfer.fit_transform(x_train)
x_test = stand_transfer.transform(x_test)
return x_train, x_test, y_train, y_test
# In[10]:
x_train, x_test, y_train, y_test = get_and_process_data()
# In[24]:
def create_model():
# n_estimators指定最终模型由多少颗决策树构成
# learning_rate由于GBDT相当于做梯度下降,这个参数相当于指定梯度下降的学习率
# max_depth指定树的最大深度
# min_sample_leaf指定叶子结点所包含最小样本数
# min_sample_split指定决策树节点少于多少个样本就不在进行分支了
gbdt = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, \
max_depth=4, min_samples_leaf=3, min_samples_split=5)
param = {"n_estimators": np.arange(80, 130, 10), "learning_rate": np.arange(0.1, 0.4, 0.1),\
"max_depth": np.arange(3, 6), "min_samples_leaf": np.arange(3, 6), \
"min_samples_split": np.arange(3, 6)}
gbdt = GridSearchCV(estimator=gbdt, param_grid=param, cv=3)
gbdt.fit(x_train, y_train)
y_test_predict = gbdt.predict(x_test)
print("(预测值, 真实值):", list(zip(y_test_predict, y_test)))
print("均方误差:", mean_squared_error(y_pred=y_test_predict, y_true=y_test))
create_model()