机器学习——线性回归的自定义实现

一 线性回归(Linear Regression)  

1. 线性回归概述

  回归的目的是预测数值型数据的目标值,最直接的方法就是根据输入写出一个求出目标值的计算公式,也就是所谓的回归方程,例如y = ax1+bx2,其中求回归系数的过程就是回归。那么回归是如何预测的呢?当有了这些回归系数,给定输入,具体的做法就是将回归系数与输入相乘,再将结果加起来就是最终的预测值。说到回归,一般指的都是线性回归,当然也存在非线性回归,在此不做讨论。

  假定输入数据存在矩阵x中,而回归系数存放在向量w中。那么对于给定的数据x1,预测结果可以通过y1 = x1Tw给出,那么问题就是来寻找回归系数。一个最常用的方法就是寻找误差最小的w,误差可以用预测的y值和真实的y值的差值表示,由于正负差值的差异,可以选用平方误差,也就是对预测的y值和真实的y值的平方求和,用矩阵可表示为:(y - xw)T(y - xw),现在问题就转换为寻找使得上述矩阵值最小的w,对w求导为:xT(y - xw),令其为0,解得:w = (xTx)-1xTy,这就是采用此方法估计出来的

注:数据集ex0.txt在此博客的最后面

2.python实现

from numpy import *
%matplotlib inline
import matplotlib.pyplot as plt

def loadDataSet(fileName):
    dataSet=[]
    labels=[]
    fr=open(fileName)
    for line in fr.readlines():
        lineArr=[]
        curLine=line.strip().split('\t')
        for i in range(len(curLine)-1):
            lineArr.append(float(curLine[i]))
        dataSet.append(lineArr)
        labels.append(float(curLine[-1]))
    return dataSet,labels

dataSet,labels=loadDataSet('dataset/ex0.txt')
def standardRegression(dataSet,labels):
    '''
    标准的回归函数的目标是:回归系数ws
    '''
    #1.x^t*x的结果
    xMat=mat(dataSet)
    yMat=mat(labels)
    #print(yMat.T)
    xTx=xMat.T*xMat
    #2.求上面的结果矩阵是否可逆
    if linalg.det(xTx)==0.0:
        print('此矩阵为不可逆矩阵')
        return
    ws=xTx.I*(xMat.T*yMat.T)
    return ws

ws=standardRegression(dataSet,labels)

#求预测值
xMat=mat(dataSet)
yMat=mat(labels).T
predict=xMat*ws
print(predict[0:5])
print(yMat[0:5])
#请加入模型的评测函数
#1.平均绝对误差  Mean Absolute Error,MAE  |y-mean(y)|求和/n
meanValue=mean(yMat)
print('均值:',meanValue)
print('长度:',len(yMat))
#print('差:',yMat-meanValue)
MAE=sum(abs(yMat-mean(yMat)))/len(yMat)
print(MAE)

#2.均方误差 Mean Squared Error ,MSE (y-predict)^2,求和/n
diff=yMat-predict
#print("diff:",diff)
squarel=square(diff)
MSE=sum(squarel)/len(yMat)
print(MSE)

#3.R-squred 1-((y-predict)^2求和)|y-mean(y)|求和
#1-(MSE/MAE)
r=1.0-(sum(square(yMat-predict))/sum(square(yMat-mean(yMat))))
print(r)
#绘制图表来显示  数据点与模型的关系
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(8,6))
ax=fig.add_subplot(111)
ax.scatter(xMat[:,1].flatten().A[0],yMat.T[0,:].flatten().A[0])
print(xMat[:,1].flatten().A[0])
xCopy=xMat.copy()
xCopy.sort(0)
predictY=xCopy*ws
ax.plot(xCopy[:,1],predictY)

plt.show()

分析结果我们可以看出线性回归得到的相关性还是挺理想的,但是从图像中明显可以看出线性回归未能捕获到一些数据点,没能很好的表示数据的变化趋势,在某种情况下存在欠拟合的情况,这是线性回归的一个缺点。在此想要说明的一点是,要只是简单的实现拟合的话,不妨采用MATLAB中的cftool的工具,简单高效直观。

二 局部加权线性回归

(Locally Weighted Linear Regression,LWLR)

1.概述

  针对于线性回归存在的欠拟合现象,可以引入一些偏差得到局部加权线性回归对算法进行优化。在该算法中,给待测点附近的每个点赋予一定的权重,进而在所建立的子集上进行给予最小均方差来进行普通的回归,分析可得回归系数w可表示为:

w = (xTWx)-1xTWy,其中W为每个数据点赋予的权重,那么怎样求权重呢,核函数可以看成是求解点与点之间的相似度,在此可以采用核函数,相应的根据预测点与附近点之间的相似程度赋予一定的权重,在此选用最常用的高斯核,则权重可以表示为:w(i,i) = exp(|x(i) - x| / -2k2),其中K为宽度参数,至于此参数的取值,目前仍没有一个确切的标准,只有一个范围的描述,所以在算法的应用中,可以采用不同的取值分别调试,进而选取最好的结果。

2.python实现

结合上述的分析,采用python编程实现,代码如下:

def juxxjq(test,xArr,yArr,k=1.0):
    '''
    本函数给定x空间的任意一个点,计算出对应的预测值yHat
    过程:
    1.读入数据,创建矩阵
    2.创建对角权重矩阵weights ,他是一个方阵,阶数等于样本点个数,即该矩阵为每个样本初始化一个权重
    3.算法将遍历数据集,计算每个样本对应的权重值,随着样本点与待测点距离的递增,权重将以指数级减少
    可以通过k控制速度
    4.在权重矩阵计算完后,就可以对回归系数的一个估计
    '''
    xMat = mat(xArr)
    yMat = mat(yArr).T
    m = shape(xMat)[0]
    weights = mat(eye((m)))#生成m阶矩阵的对角线
    for j in range(m):
        diffMat = test - xMat[j,:]
        weights[j,j]  = exp(diffMat*diffMat.T/(-2.0*k**2))#计算权重
    xTx = xMat.T * (weights * xMat) #计算x^T*w*x
    #判断行列式是否为0,如果为0,则计算逆矩阵时会出错
    if linalg.det(xTx)==0.0:
        print('此矩阵为不可逆矩阵')
        return
    #xTx.I指:(x^T*w*x)^-1
    ws=xTx.I*(xMat.T*(weights*yMat))
    return test * ws
        
#测试一个点
xArr,yArr=loadDataSet('dataset/ex0.txt')
print('原始数值:',yArr[0])
print(juxxjq(xArr[0],xArr,yArr,1.0))
print(juxxjq(xArr[0],xArr,yArr,0.0001))

#测试整个数据集
def juxxjqTest(testArr,xArr,yArr,k=1.0):
    m = shape(testArr)[0]
    yHat = zeros(m)
    for i in range(m):
        yHat[i] = juxxjq(testArr[i],xArr,yArr,k)
    return yHat

yHat=juxxjqTest(xArr,xArr,yArr,0.003)
xMat=mat(xArr)
srtInd=xMat[:,1].argsort(0)#按升序排序,返回下标
xSort=xMat[srtInd][:,0,:]#将xMat按照升序排列
print(yHat[0:5])
#绘制图
import matplotlib.pyplot as plt
fig=plt.figure(figsize=(7,8))
ax=fig.add_subplot(111)
ax.plot(xSort[:,1],yHat[srtInd])
ax.scatter(mat(xArr)[:,1].flatten().A[0],mat(yArr).T.flatten().A[0],s=2,c='red')#s表示红点的大小
plt.show()

试着改变k的值,k=1时权重很大,如同所有数据视为等权重,得到的最佳拟合直线与标准回归,k = 0.003得到的结果更理想。虽然LWLR得到了较为理想的结果,但是此种方法的缺点是在对每个点进行预测时都必须遍历整个数据集,这样无疑是增加了工作量,并且该方法中的的宽度参数的取值对于结果的影响也是蛮大的。同时,当数据的特征比样本点还多当然是用线性回归和之前的方法是不能实现的,当特征比样本点还多时,表明输入的矩阵X不是一个满秩矩阵,在计算(XTX)-1时会出错。

数据集内容:ex0.txt 的内容如下:

1.000000    0.067732    3.176513
1.000000    0.427810    3.816464
1.000000    0.995731    4.550095
1.000000    0.738336    4.256571
1.000000    0.981083    4.560815
1.000000    0.526171    3.929515
1.000000    0.378887    3.526170
1.000000    0.033859    3.156393
1.000000    0.132791    3.110301
1.000000    0.138306    3.149813
1.000000    0.247809    3.476346
1.000000    0.648270    4.119688
1.000000    0.731209    4.282233
1.000000    0.236833    3.486582
1.000000    0.969788    4.655492
1.000000    0.607492    3.965162
1.000000    0.358622    3.514900
1.000000    0.147846    3.125947
1.000000    0.637820    4.094115
1.000000    0.230372    3.476039
1.000000    0.070237    3.210610
1.000000    0.067154    3.190612
1.000000    0.925577    4.631504
1.000000    0.717733    4.295890
1.000000    0.015371    3.085028
1.000000    0.335070    3.448080
1.000000    0.040486    3.167440
1.000000    0.212575    3.364266
1.000000    0.617218    3.993482
1.000000    0.541196    3.891471
1.000000    0.045353    3.143259
1.000000    0.126762    3.114204
1.000000    0.556486    3.851484
1.000000    0.901144    4.621899
1.000000    0.958476    4.580768
1.000000    0.274561    3.620992
1.000000    0.394396    3.580501
1.000000    0.872480    4.618706
1.000000    0.409932    3.676867
1.000000    0.908969    4.641845
1.000000    0.166819    3.175939
1.000000    0.665016    4.264980
1.000000    0.263727    3.558448
1.000000    0.231214    3.436632
1.000000    0.552928    3.831052
1.000000    0.047744    3.182853
1.000000    0.365746    3.498906
1.000000    0.495002    3.946833
1.000000    0.493466    3.900583
1.000000    0.792101    4.238522
1.000000    0.769660    4.233080
1.000000    0.251821    3.521557
1.000000    0.181951    3.203344
1.000000    0.808177    4.278105
1.000000    0.334116    3.555705
1.000000    0.338630    3.502661
1.000000    0.452584    3.859776
1.000000    0.694770    4.275956
1.000000    0.590902    3.916191
1.000000    0.307928    3.587961
1.000000    0.148364    3.183004
1.000000    0.702180    4.225236
1.000000    0.721544    4.231083
1.000000    0.666886    4.240544
1.000000    0.124931    3.222372
1.000000    0.618286    4.021445
1.000000    0.381086    3.567479
1.000000    0.385643    3.562580
1.000000    0.777175    4.262059
1.000000    0.116089    3.208813
1.000000    0.115487    3.169825
1.000000    0.663510    4.193949
1.000000    0.254884    3.491678
1.000000    0.993888    4.533306
1.000000    0.295434    3.550108
1.000000    0.952523    4.636427
1.000000    0.307047    3.557078
1.000000    0.277261    3.552874
1.000000    0.279101    3.494159
1.000000    0.175724    3.206828
1.000000    0.156383    3.195266
1.000000    0.733165    4.221292
1.000000    0.848142    4.413372
1.000000    0.771184    4.184347
1.000000    0.429492    3.742878
1.000000    0.162176    3.201878
1.000000    0.917064    4.648964
1.000000    0.315044    3.510117
1.000000    0.201473    3.274434
1.000000    0.297038    3.579622
1.000000    0.336647    3.489244
1.000000    0.666109    4.237386
1.000000    0.583888    3.913749
1.000000    0.085031    3.228990
1.000000    0.687006    4.286286
1.000000    0.949655    4.628614
1.000000    0.189912    3.239536
1.000000    0.844027    4.457997
1.000000    0.333288    3.513384
1.000000    0.427035    3.729674
1.000000    0.466369    3.834274
1.000000    0.550659    3.811155
1.000000    0.278213    3.598316
1.000000    0.918769    4.692514
1.000000    0.886555    4.604859
1.000000    0.569488    3.864912
1.000000    0.066379    3.184236
1.000000    0.335751    3.500796
1.000000    0.426863    3.743365
1.000000    0.395746    3.622905
1.000000    0.694221    4.310796
1.000000    0.272760    3.583357
1.000000    0.503495    3.901852
1.000000    0.067119    3.233521
1.000000    0.038326    3.105266
1.000000    0.599122    3.865544
1.000000    0.947054    4.628625
1.000000    0.671279    4.231213
1.000000    0.434811    3.791149
1.000000    0.509381    3.968271
1.000000    0.749442    4.253910
1.000000    0.058014    3.194710
1.000000    0.482978    3.996503
1.000000    0.466776    3.904358
1.000000    0.357767    3.503976
1.000000    0.949123    4.557545
1.000000    0.417320    3.699876
1.000000    0.920461    4.613614
1.000000    0.156433    3.140401
1.000000    0.656662    4.206717
1.000000    0.616418    3.969524
1.000000    0.853428    4.476096
1.000000    0.133295    3.136528
1.000000    0.693007    4.279071
1.000000    0.178449    3.200603
1.000000    0.199526    3.299012
1.000000    0.073224    3.209873
1.000000    0.286515    3.632942
1.000000    0.182026    3.248361
1.000000    0.621523    3.995783
1.000000    0.344584    3.563262
1.000000    0.398556    3.649712
1.000000    0.480369    3.951845
1.000000    0.153350    3.145031
1.000000    0.171846    3.181577
1.000000    0.867082    4.637087
1.000000    0.223855    3.404964
1.000000    0.528301    3.873188
1.000000    0.890192    4.633648
1.000000    0.106352    3.154768
1.000000    0.917886    4.623637
1.000000    0.014855    3.078132
1.000000    0.567682    3.913596
1.000000    0.068854    3.221817
1.000000    0.603535    3.938071
1.000000    0.532050    3.880822
1.000000    0.651362    4.176436
1.000000    0.901225    4.648161
1.000000    0.204337    3.332312
1.000000    0.696081    4.240614
1.000000    0.963924    4.532224
1.000000    0.981390    4.557105
1.000000    0.987911    4.610072
1.000000    0.990947    4.636569
1.000000    0.736021    4.229813
1.000000    0.253574    3.500860
1.000000    0.674722    4.245514
1.000000    0.939368    4.605182
1.000000    0.235419    3.454340
1.000000    0.110521    3.180775
1.000000    0.218023    3.380820
1.000000    0.869778    4.565020
1.000000    0.196830    3.279973
1.000000    0.958178    4.554241
1.000000    0.972673    4.633520
1.000000    0.745797    4.281037
1.000000    0.445674    3.844426
1.000000    0.470557    3.891601
1.000000    0.549236    3.849728
1.000000    0.335691    3.492215
1.000000    0.884739    4.592374
1.000000    0.918916    4.632025
1.000000    0.441815    3.756750
1.000000    0.116598    3.133555
1.000000    0.359274    3.567919
1.000000    0.814811    4.363382
1.000000    0.387125    3.560165
1.000000    0.982243    4.564305
1.000000    0.780880    4.215055
1.000000    0.652565    4.174999
1.000000    0.870030    4.586640
1.000000    0.604755    3.960008
1.000000    0.255212    3.529963
1.000000    0.730546    4.213412
1.000000    0.493829    3.908685
1.000000    0.257017    3.585821
1.000000    0.833735    4.374394
1.000000    0.070095    3.213817
1.000000    0.527070    3.952681
1.000000    0.116163    3.129283

猜你喜欢

转载自blog.csdn.net/WJWFighting/article/details/81287627