基于矩阵求导的线性回归

矩阵求导

  1. 考虑矩阵乘法 A ⋅ B = C A \cdot B = C AB=C

  2. 考虑Loss函数 L = ∑ i m ∑ j n ( C i j − p ) 2 L = \sum^m_{i}\sum^n_{j}{(C_{ij} - p)^2} L=imjn(Cijp)2

  3. 考虑C的每一项导数 ▽ C i j = ∂ L ∂ C i j \triangledown C_{ij} = \frac{\partial L}{\partial C_{ij}} Cij=CijL

  4. 考虑ABC都为2x2矩阵时,定义G为L对C的导数
    A = [ a b c d ] B = [ e f g h ] C = [ i j k l ] G = ∂ L ∂ C = [ ∂ L ∂ i ∂ L ∂ j ∂ L ∂ k ∂ L ∂ l ] = [ w x y z ] A = \begin{bmatrix} a & b\\ c & d \end{bmatrix} \quad B = \begin{bmatrix} e & f \\ g & h \end{bmatrix} \quad C = \begin{bmatrix} i & j \\ k & l \end{bmatrix} \quad G = \frac{\partial L}{\partial C} = \begin{bmatrix} \frac{\partial L}{\partial i} & \frac{\partial L}{\partial j} \\ \frac{\partial L}{\partial k} & \frac{\partial L}{\partial l} \end{bmatrix} = \begin{bmatrix} w & x \\ y & z \end{bmatrix} A=[acbd]B=[egfh]C=[ikjl]G=CL=[iLkLjLlL]=[wyxz]

  5. 展开左边 A x B
    C = [ i = a e + b g j = a f + b h k = c e + d g l = c f + d h ] C = \begin{bmatrix} i = ae + bg & j = af + bh\\ k = ce + dg & l = cf + dh \end{bmatrix} C=[i=ae+bgk=ce+dgj=af+bhl=cf+dh]

  6. L对于每一个A的导数 ▽ A i j = ∂ L ∂ A i j \triangledown A_{ij} = \frac{\partial L}{\partial A_{ij}} Aij=AijL

  7. ∂ L ∂ a = ∂ L ∂ i ∗ ∂ i ∂ a + ∂ L ∂ j ∗ ∂ j ∂ a \frac{\partial L}{\partial a} = \frac{\partial L}{\partial i} * \frac{\partial i}{\partial a} + \frac{\partial L}{\partial j} * \frac{\partial j}{\partial a} aL=iLai+jLaj

∂ L ∂ b = ∂ L ∂ i ∗ ∂ i ∂ b + ∂ L ∂ j ∗ ∂ j ∂ b \frac{\partial L}{\partial b} = \frac{\partial L}{\partial i} * \frac{\partial i}{\partial b} + \frac{\partial L}{\partial j} * \frac{\partial j}{\partial b} bL=iLbi+jLbj

∂ L ∂ c = ∂ L ∂ k ∗ ∂ k ∂ c + ∂ L ∂ l ∗ ∂ l ∂ c \frac{\partial L}{\partial c} = \frac{\partial L}{\partial k} * \frac{\partial k}{\partial c} + \frac{\partial L}{\partial l} * \frac{\partial l}{\partial c} cL=kLck+lLcl

∂ L ∂ d = ∂ L ∂ k ∗ ∂ k ∂ d + ∂ L ∂ l ∗ ∂ l ∂ d \frac{\partial L}{\partial d} = \frac{\partial L}{\partial k} * \frac{\partial k}{\partial d} + \frac{\partial L}{\partial l} * \frac{\partial l}{\partial d} dL=kLdk+lLdl
∂ L ∂ a = w e + x f \frac{\partial L}{\partial a} = we + xf \\ aL=we+xf
∂ L ∂ b = w g + x h \frac{\partial L}{\partial b} = wg + xh \\ bL=wg+xh
∂ L ∂ c = y e + z f \frac{\partial L}{\partial c} = ye + zf \\ cL=ye+zf
∂ L ∂ d = y g + z h \frac{\partial L}{\partial d} = yg + zh dL=yg+zh
7. 因此A的导数为 ▽ A = [ w e + x f w g + x h y e + z f y g + z h ] ▽ A = [ w x y z ] [ e g f h ] \triangledown A = \begin{bmatrix} we + xf & wg + xh\\ ye + zf & yg + zh \end{bmatrix} \quad \triangledown A = \begin{bmatrix} w & x\\ y & z \end{bmatrix} \begin{bmatrix} e & g\\ f & h \end{bmatrix} A=[we+xfye+zfwg+xhyg+zh]A=[wyxz][efgh]

▽ A = G ⋅ B T \triangledown A = G \cdot B^T A=GBT
8. 同理B的导数为:
∂ L ∂ e = w a + y c \frac{\partial L}{\partial e} = wa + yc \\ eL=wa+yc
∂ L ∂ f = x a + z c \frac{\partial L}{\partial f} = xa + zc \\ fL=xa+zc
∂ L ∂ g = w b + y d \frac{\partial L}{\partial g} = wb + yd \\ gL=wb+yd
∂ L ∂ h = x b + z d \frac{\partial L}{\partial h} = xb + zd hL=xb+zd
▽ A = [ w a + y c x a + z c w b + y d x b + z d ] ▽ A = [ a c b d ] [ w x y z ] \triangledown A = \begin{bmatrix} wa + yc & xa + zc\\ wb + yd & xb + zd \end{bmatrix} \quad \triangledown A = \begin{bmatrix} a & c\\ b & d \end{bmatrix} \begin{bmatrix} w & x\\ y & z \end{bmatrix} A=[wa+ycwb+ydxa+zcxb+zd]A=[abcd][wyxz]

▽ B = A T ⋅ G \triangledown B = A^T \cdot G B=ATG

Secondhand-house-price-in-Shanghai.csv

房价(元/平米),室,厅,卫,面积(平米),楼层,建成年份
78678,2,1,1,76.26,6,2003
83942,3,2,2,137,11,2003
53942,1,0,1,40.97,6,1989
68929,2,1,1,75.44,6,1995
99163,1,1,1,45.38,6,1996
72442,3,1,1,100.77,24,1998
93637,2,0,1,52.33,6,1995
68605,1,1,1,43,18,2000
63182,1,1,1,34.82,6,1986
57279,1,1,1,41.9,6,1993
50938,2,1,1,74.6,24,2002
65646,1,1,1,41.13,6,1993
65128,3,1,1,78,6,1997
69662,2,1,1,70.34,6,1998
57494,2,2,1,100.88,24,1996
45651,2,1,1,83.24,15,2016
72482,2,2,1,78.64,6,2003
69901,2,1,1,60.8,6,1996
67237,3,1,1,81.8,6,1996
64002,1,1,1,42.03,6,1988
54778,1,0,1,32.86,7,1987
70822,1,1,1,42.36,6,1983
84548,2,2,1,89.89,23,2011
57971,2,1,1,69,6,1995
50000,2,1,1,61,6,1997
57955,1,1,1,44,7,1995
48532,2,1,1,83.45,20,1998
50134,2,2,1,89.76,5,2005
65106,2,1,1,64.51,6,2001
62687,2,2,1,67,6,1998
62687,2,2,1,67,6,1998
67238,2,1,1,59.49,6,1996
58956,1,1,1,35.62,6,1984
28788,2,1,1,66,17,2016
55866,2,1,1,62.65,6,1996
64853,2,1,1,63.22,6,2003
55757,3,1,1,71.74,6,1996
60484,1,1,1,44.64,6,1990
51523,2,2,1,89.28,14,2010
63665,1,1,1,43.98,6,1994
59418,1,1,1,50.49,6,1993
69832,2,1,1,71.6,6,1993
64767,2,1,1,46.32,5,1977
48875,3,2,2,125.83,6,2006
57213,2,1,1,73.41,6,1994
49970,2,1,1,66.04,6,2003
68750,1,1,1,48,24,1995
65763,2,1,1,50.18,6,1987
70964,3,1,1,84.55,6,1996
43820,2,2,1,89,14,2006
57158,1,1,1,38.49,6,1982
67186,3,1,1,65.49,6,1993
67164,2,1,1,80.4,17,2009
83195,1,1,1,30.05,6,1979
61172,2,1,1,62.12,6,1994
59640,3,2,1,106.64,18,2007
69041,2,1,1,63.73,25,1990
56320,3,1,1,92.33,6,2001
60993,2,1,1,68.86,7,1996
68901,3,2,2,117.56,5,1994
55596,1,1,1,41.37,6,1988
65839,1,1,1,40.25,6,1990
59161,1,1,1,43.61,6,1976
67196,2,1,1,49.11,6,1982
59488,2,1,1,84.05,7,1997
67377,1,1,1,31.91,6,1984
51954,2,1,1,84.69,11,2010
60370,3,2,1,74.54,6,1993
57410,2,1,1,56.61,6,1994
60267,2,1,1,75,24,2002
55858,2,2,1,69.82,6,1998
89711,1,0,1,35.67,6,1991
56005,2,1,1,48.21,5,1954
70762,1,1,1,36.46,6,1978
62387,2,1,1,49.69,6,1995
52581,2,2,1,82.73,6,1994
64346,2,1,1,47.4,6,1991
103143,2,1,1,61.08,6,1994
48209,2,2,1,91.27,6,2010
82902,2,2,1,94.69,24,2008
73052,2,1,1,55.44,6,1989
93922,2,1,1,56.43,14,1975
41990,2,1,1,72.16,6,1993
88574,1,1,1,56.45,13,2014
106329,2,1,1,79,6,2003
69901,2,1,1,71.53,6,1994
54884,3,2,1,104.22,11,2013
50786,1,2,1,61.04,14,2007
79091,2,1,1,69.54,6,1995
55402,2,1,1,86.64,6,1998
70919,2,1,1,56.12,6,1990
58133,1,0,1,35.78,6,1990
56976,3,1,1,71.96,6,1994
61050,2,0,1,60.77,18,1989
96978,2,2,2,97.96,8,1998
74224,1,1,1,44.46,6,1993
62822,2,0,1,48.55,6,1993
69532,2,1,1,47.46,6,1984
48527,1,1,1,57.7,6,2012
64017,2,1,1,58.11,6,1994
33910,2,1,1,58.98,6,1995
54897,2,1,1,69.22,6,1996
36111,1,1,1,54,6,2005
51556,2,1,1,56.25,6,2004
78697,2,1,1,64.17,25,1990
63497,1,0,1,30.71,6,1969
61434,3,2,2,130.22,11,2006
64103,1,1,1,43.68,6,1995
60217,1,1,1,50.65,6,1996
67579,2,2,1,61.41,5,1989
83926,2,2,1,101.28,6,1998
65657,2,1,1,49.5,6,1993
90666,2,1,1,48.53,6,1982
71271,1,1,1,36.2,6,1983
83895,3,2,2,145.42,30,2002
64511,2,1,1,63.4,6,1990
98187,1,1,1,39.72,6,1990
34443,3,1,1,86.81,6,2001
87676,1,0,1,40.49,6,1989
63901,2,1,1,57.12,6,1996
91811,2,1,1,67.53,6,1994
56127,1,1,1,42.76,6,1983
83515,2,1,1,105.37,18,2006
68437,2,1,1,73.79,6,1997
83404,2,1,1,70.5,6,1994
103407,2,1,1,75.43,6,1996
67954,2,1,1,68.87,6,1996
63129,2,1,1,50.69,6,1987
69241,2,1,1,61.38,6,1996
68428,1,1,1,39.75,6,1986
87797,3,2,2,119.48,6,2002
73128,1,1,1,45.81,6,1989
56860,1,1,1,54.52,16,2005
84786,2,1,1,63.1,6,1994
43051,1,1,1,52.96,5,2007
71040,2,1,1,84.46,6,1998
104128,2,1,1,53.3,6,1989
59177,2,1,1,51.54,6,1983
71948,2,1,1,78.39,18,1996
65233,1,1,1,41.39,6,1986
41506,2,1,1,77.58,6,1998
71303,2,0,1,56.8,6,1994
66913,2,1,1,67.7,6,1994
64597,1,1,1,55.73,6,1997
76320,1,1,1,48.48,7,1998
50287,2,2,2,99.43,11,2002
59843,1,1,1,36.93,6,1982
79458,1,1,1,42.79,6,1995
60490,1,1,1,33.89,6,1984
90734,2,2,1,103.6,13,2002
60870,1,1,1,34.5,6,1986
62782,2,1,1,50.97,6,1985
54040,2,1,1,56.44,6,1988
54688,1,1,1,32,5,1954
61618,2,1,1,55.99,6,1985
61622,1,1,1,34.89,6,1986
57223,1,1,1,39.32,6,1986
83311,1,1,1,37.81,6,1984
78684,1,1,1,34.95,6,1986
85692,2,1,1,65.35,6,1987
65327,1,1,1,39.8,6,1988
64349,2,1,1,67.6,6,1993
25237,2,2,1,130.76,21,2012
77432,2,1,1,52.95,5,1984
37222,3,1,1,90,6,1994
92563,2,1,1,61.58,6,1990
84242,1,1,1,33,6,1986
69724,1,1,1,37.29,6,1995
89243,2,1,1,91.66,22,1991
73389,3,2,1,99.47,6,2003
85274,2,1,1,58.4,6,1993
67969,2,0,1,52.23,6,1986
61659,2,1,1,71.36,6,1993
48115,2,1,1,87.29,10,1998
66411,1,1,1,39.15,6,1984
98995,2,1,1,65.66,6,1993
102436,1,0,1,36.12,6,1987
57301,1,1,1,47.12,22,2012
58008,1,1,1,39.65,6,1989
69901,3,1,1,72.96,6,1994
79289,1,1,1,43.89,6,1994
67524,2,0,0,62.2,5,1986
67267,1,1,1,42.22,6,1995
65794,2,1,1,87.85,17,2009
62617,2,0,1,53.5,6,1994
82649,2,1,1,62.07,12,2005
81825,2,1,1,57.44,6,1994
55920,2,2,1,85.48,17,2007
79152,4,2,2,145.29,18,2005
70943,1,2,1,50.04,7,1988
73302,2,1,1,61.39,6,1993
81633,2,1,1,63.7,6,1996
50117,2,1,1,93.78,7,1997
54379,1,1,1,51.49,6,1994
68582,2,1,1,53.95,6,1983
63805,1,1,1,34.48,6,1984
52588,1,1,1,48.49,6,1995
72027,2,1,1,59.7,6,1996
62289,2,1,1,56.19,6,1994
35077,2,1,1,65.57,14,2009
65870,1,1,1,40.99,6,1984
83784,1,1,1,37,6,1986
33949,1,1,1,53.02,11,2009
59500,3,3,2,100,14,2005
54537,1,2,1,63.26,16,2008
62165,2,1,1,63.38,6,1995
77548,2,1,1,40.62,6,1988
91820,2,1,1,59.9,11,2006
61338,1,1,1,50.54,24,1995
71077,1,1,1,54.87,25,1990
69758,1,1,1,50.89,28,2004
63449,2,1,1,52.01,6,1987
107037,1,1,1,43.91,6,1992
71884,1,1,1,47.02,6,1996
71949,2,2,1,111.19,11,2003
60680,1,1,1,49.44,6,1995
76394,2,1,1,91.63,6,2003
35400,2,1,1,64.69,6,2012
54531,2,1,1,62.35,6,1995
55762,1,1,1,59.18,17,2010
63837,2,1,1,58.9,6,1996
95254,1,1,1,29.92,6,1989
57368,2,1,1,67.11,6,1994
65837,2,0,1,41.77,4,1979
76453,2,2,2,98.1,6,2002
81412,2,1,1,88.93,6,1995
57803,2,1,1,74.39,6,1998
53606,2,2,1,98.87,6,2001
35410,2,2,1,79.92,11,2009
71097,1,1,1,51.76,6,1996
59894,1,1,1,45.08,6,1997
77685,2,1,1,72.73,7,1996
38286,1,1,1,50.41,28,2014
70000,2,1,1,80,6,2000
53412,2,1,1,71.52,6,1993
99278,2,1,1,55.4,6,1988
70563,2,1,1,55.27,6,1993
56096,1,1,1,44.21,6,1994
78125,2,2,1,80,6,1998
67742,2,1,1,62,6,2004
54605,3,1,1,76,6,1998
65893,1,1,1,31.87,6,1989
71708,1,1,1,30.68,6,1989
60734,1,0,1,37.87,6,1989
77398,1,1,1,56.72,6,2003
90286,1,1,1,43.75,6,1994
35065,2,1,1,77,6,2009
87209,3,1,1,60.2,4,1993
61932,2,1,1,59.42,6,1994
74038,3,1,1,71.72,6,1993
142107,3,2,2,161.85,9,2011
61849,1,1,1,30.72,5,1953
89891,2,1,1,72.31,24,1992
89282,1,1,1,44.69,6,1989
61354,2,2,1,70.9,6,1995
65632,2,2,1,79.23,6,1995
80932,2,1,1,50.66,6,1984
75697,2,1,1,75.3,6,1996
69779,3,1,1,62.34,7,1988
65547,2,2,1,137,25,1998
59175,2,1,1,83.65,6,2006
79795,3,0,1,64.54,4,1980
74557,2,1,1,95.9,28,1995
74042,2,1,1,57.4,6,1984
72947,1,1,1,55.52,24,1993
67445,2,1,1,49.67,6,1984
70943,1,2,1,50.04,7,1988
72069,2,1,1,62.44,6,1993
68811,2,1,1,58.13,6,1990
55627,1,1,1,46.74,6,1994
import numpy as np
import pandas as pd
from tqdm import tqdm
import pickle


def get_data(file = "./Secondhand-house-price-in-Shanghai.csv"):
    datas = pd.read_csv(file,names=["y","x1","x2","x3","x4","x5","x6"],skiprows = 1)

    y = datas["y"].values.reshape(-1,1)
    X = datas[[f"x{i}" for i in range(1,7)]].values

    # z-score :  (x - mean_x) / std
    mean_y = np.mean(y)
    std_y = np.std(y)

    mean_X = np.mean(X,axis = 0, keepdims = True)
    std_X =  np.std(X,axis = 0,keepdims = True)

    y = (y - mean_y) / std_y
    X = (X -  mean_X) / std_X

    return X,y,mean_y,std_y,mean_X,std_X


if __name__ == "__main__":
    X,y,mean_y,std_y,mean_X,std_X = get_data()
    K = np.random.random((6,1))

    epoch = 1000
    lr = 0.1
    b = 0

    for e in range(epoch):
        pre = X @ K + b
        loss = np.sum((pre - y)**2)/len(X)

        G = (pre - y ) / len(X)
        delta_K = X.T @ G
        delta_b = np.mean(G)

        K = K - lr * delta_K
        b = b - lr * delta_b

        print(f"loss:{loss:.3f}")

    while True:
        bedroom = (int(input("请输入卧室数量:")))
        ting = (int(input("请输入客厅数量:")))
        wei = (int(input("请输入卫生间数量:")))
        area = (int(input("请输入面积:")))
        floor = (int(input("请输入楼层:")))
        year = (int(input("请输入建成年份:")))

        test_x = (np.array([bedroom, ting, wei, area, floor, year]).reshape(1, -1) - mean_X) / std_X

        p = test_x @ K + b
        print("房价为: ", p * std_y + mean_y)


pre = X @ K + b
loss = np.sum((pre - y)**2)/len(X)

G = (pre - y ) / len(X)
delta_K = X.T @ G
delta_b = np.mean(G)

K = K - lr * delta_K
b = b - lr * delta_b

矩阵求导

上述代码中使用了矩阵求导的方式来计算梯度,这是一种更加高效的方法,可以同时计算多个参数的梯度,减少了循环的次数,提高了计算效率。

具体来说,假设我们有一个输入矩阵X和一个参数向量K,则输出可以表示为:

pre = X @ K + b
其中,@表示矩阵乘法运算。我们的目标是最小化损失函数:

loss = np.sum((pre - y)**2) / len(X)
为了计算参数向量K和b的梯度,我们需要对损失函数进行求导,得到:

dL/dK = X.T @ (pre - y) / len(X)
dL/db = np.mean(pre - y)
其中,X.T表示矩阵X的转置。

然后,我们可以使用梯度下降算法更新参数向量K和b,具体来说:

K = K - lr * dL/dK
b = b - lr * dL/db
其中,lr表示学习率,控制每次参数更新的步长。

使用矩阵求导的方式可以大大减少代码的复杂度,提高计算效率,同时还能够避免一些常见的错误,例如矩阵索引不一致、维度不匹配等问题。

猜你喜欢

转载自blog.csdn.net/qq_44089890/article/details/130003045