scikit learn 之OLS

scikit learn 之OLS

最近刚刚开始学习Python,之前是对R语言进行的机器学习比较熟悉了,所以算法的数学推导还算理解,主要是针对于Python中的函数进行了学习嘿嘿~希望自己能够输出带动学习!!

代码

from sklearn import linear_model
reg = linear_model.LinearRegression()
reg.fit([[0,0],[1,1],[2,2]],[0,1,2])
reg.coef_
#example
import matplotlib.pyplot as plt
import numpy as np
from sklearn import linear_model,datasets
from sklearn.metrics import mean_squared_error,r2_score

diabetes = datasets.load_diabetes() #获得数据集
diabetes_X = diabetes.data[:,np.newaxis,2] #获取第三列数据

这个语句等同于diabetes.data[:,2][:,np.newaxis],是因为在多元数组的索引中,索引某一列返回的是一行,为了使其变化成一列,我们加入newaxis增加一个维度。参考https://blog.csdn.net/lanchunhui/article/details/49725065

diabetes_X_train = diabetes_X[:-20]#训练集
diabetes_X_test = diabetes_X[-20:]#测试集(后二十条数据)
diabetes_Y_train = diabetes.target[:-20]
diabetes_Y_test = diabetes.target[-20:]
reg = linear_model.LinearRegression()#建立模型对象
reg.fit(diabetes_X_train,diabetes_Y_train)#训练模型
diabetes_Y_predict = reg.predict(diabetes_X_test)#预测测试集
print('Coefficients: \n', reg.coef_)#系数
print('the mean squared error is %.2f'
      %mean_squared_error(diabetes_Y_predict,diabetes_Y_test))#均方误差
print('the r2 of the model is %.2f'%r2_score(diabetes_Y_test,diabetes_Y_predict))#R2
#r2_score(truevalue,predictvalue)注意顺序hhh,不然就是负数啦,均方误差就没有关系啦,毕竟是个平方和~
plt.scatter(diabetes_X_test,diabetes_Y_test,color='black')#散点图
plt.plot(diabetes_X_test,diabetes_Y_predict,color="blue",linewidth=3)#画线
plt.show
plt.xticks(())#把横纵坐标轴变成空的了
plt.yticks(())

猜你喜欢

转载自blog.csdn.net/weixin_43451186/article/details/86220290