基于sciket-learn实现多项式回归

多项式回归在思想上和线性回归是一致的,都使用一条线去拟合样本值,进入用得出的模型去进行预测,在样本特征呈现出线性特性时,我们可以用线性回归去做预测,但是在样本特征很复杂的时候,线性回归往往会呈现出欠拟合的状态,这时就需要多项式回归。

先来看一个小例子,给定一条二次曲线y=2x^2 + 2x,生成带噪声的100个样本点,绘制出图像 ,是我们熟悉的二次方程。

x = np.random.uniform(-3, 3, size = 100)
X = x.reshape(-1, 1)

y = 2 * x**2 + x + 2 + np.random.normal(0, 1, size = 100)

plt.scatter(x, y)
plt.show()

然后我们用线性回归来拟合这条曲线,看看会出现什么情况

from sklearn.linear_model import LinearRegression

lin_reg = LinearRegression()
lin_reg.fit(X, y)

y_predict = lin_reg.predict(X)

plt.scatter(x, y)
plt.plot(x, y_predict, color='r')
plt.show()

很明显我们的预测函数没有很好的拟合这些样本点,当遇到这种情况时,我们不妨在增加一个特征

X2 = np.hstack([X, X**2])

然后同样调用sciket-learn为我们封装好线性回归构造器,接着绘制出图像

lin_reg2 = LinearRegression()
lin_reg2.fit(X2, y)
y_predict2 = lin_reg2.predict(X2)

plt.scatter(x, y)
plt.plot(np.sort(x), y_predict2[np.argsort(x)], color='r')
plt.show()

这时,便可以看到,拟合程度已经比较好了。

sciket-learn中为我们提供了PolynomialFeatures来确定特征的维度。

from sklearn.preprocessing import PolynomialFeatures

poly = PolynomialFeatures(degree=2)
poly.fit(X)
X3 = poly.transform(X)

lin_reg3 = LinearRegression()
lin_reg3.fit(X3, y)
y_predict3 = lin_reg3.predict(X3)

plt.scatter(x, y)
plt.plot(np.sort(x), y_predict3[np.argsort(x)], color='r')
plt.show()

可以看出得到的图像和上面的图像是一致的,这里有兴趣的朋友可以改变degree参数的值,看看会发生什么样的变化。

完整代码

import numpy as np
import matplotlib.pyplot as plt

x = np.random.uniform(-3, 3, size = 100)
X = x.reshape(-1, 1)
y = 2 * x**2 + x + 2 + np.random.normal(0, 1, size = 100)

plt.scatter(x, y)
plt.show()

from sklearn.linear_model import LinearRegression

lin_reg = LinearRegression()
lin_reg.fit(X, y)

y_predict = lin_reg.predict(X)

plt.scatter(x, y)
plt.plot(x, y_predict, color='r')
plt.show()

# 解决方案 添加一个特征
(X**2).shape
X2 = np.hstack([X, X**2])
X2.shape

plt.scatter(x, y)
plt.plot(np.sort(x), y_predict2[np.argsort(x)], color='r')
plt.show()

lin_reg2.coef_
lin_reg2.intercept_

from sklearn.preprocessing import PolynomialFeatures

poly = PolynomialFeatures(degree=2)
poly.fit(X)
X3 = poly.transform(X)

lin_reg3 = LinearRegression()
lin_reg3.fit(X3, y)
y_predict3 = lin_reg3.predict(X3)

plt.scatter(x, y)
plt.plot(np.sort(x), y_predict3[np.argsort(x)], color='r')
plt.show()

猜你喜欢

转载自blog.csdn.net/sinat_33150417/article/details/83619522