deeplearning.ai 总结 - C++实现lstm cell

flyfish

//各个变量的含义


    //  xt -- your input data at timestep "t", array of shape(n_x, m).
    //  a_prev -- Hidden state at timestep "t-1", array of shape(n_a, m)
    //  c_prev -- Memory state at timestep "t-1", array of shape(n_a, m)

    //  Wf -- Weight matrix of the forget gate, array of shape(n_a, n_a + n_x)
    //  bf -- Bias of the forget gate, array of shape(n_a, 1)
    //  Wi -- Weight matrix of the update gate, array of shape(n_a, n_a + n_x)
    //  bi -- Bias of the update gate, array of shape(n_a, 1)
    //  Wc -- Weight matrix of the first "tanh", array of shape(n_a, n_a + n_x)
    //  bc --  Bias of the first "tanh", array of shape(n_a, 1)
    //  Wo -- Weight matrix of the output gate, array of shape(n_a, n_a + n_x)
    //  bo --  Bias of the output gate, array of shape(n_a, 1)
    //  Wy -- Weight matrix relating the hidden-state to the output, array of shape(n_y, n_a)
    //  by -- Bias relating the hidden-state to the output, array of shape(n_y, 1)


    // Random test data with n_x = 3 (input size), n_a = 5 (hidden size),
    // n_y = 2 (output size) and m = 10 (examples). Several equivalent ways of
    // filling Eigen matrices with random values are shown; any one of them works.
    // A concise one is e.g. Eigen::MatrixXd Wf = Eigen::MatrixXd::Random(5, 3);
    Eigen::Matrix<double, 3, 10> xt;
    xt.setRandom();
    std::cout << "xt1:\n" << xt << std::endl;


    Eigen::Matrix<double, 5, 10>  a_prev;
    a_prev.setRandom();
    std::cout << "a_prev:\n" << a_prev << std::endl;


    Eigen::Matrix<double, 5, 10>  c_prev = Eigen::MatrixXd::Random(5, 10);
    Eigen::MatrixXd  Wf = Eigen::MatrixXd::Random(5, 5 + 3);


    Eigen::VectorXd bf(5); // equivalent to Eigen::Matrix<double, 5, 1> bf;
    bf.setRandom();

    Eigen::Matrix<double, 5, 5 + 3> Wi = Eigen::MatrixXd::Random(5, 5 + 3);
    Eigen::Matrix<double, 5, 1> bi = Eigen::MatrixXd::Random(5, 1);
    Eigen::Matrix<double, 5, 5 + 3> Wo = Eigen::MatrixXd::Random(5, 5 + 3);
    Eigen::Matrix<double, 5, 1> bo = Eigen::MatrixXd::Random(5, 1);
    Eigen::Matrix<double, 5, 5 + 3> Wc = Eigen::MatrixXd::Random(5, 5 + 3);
    Eigen::Matrix<double, 5, 1> bc = Eigen::MatrixXd::Random(5, 1);
    Eigen::Matrix<double, 2, 5> Wy = Eigen::MatrixXd::Random(2, 5);
    Eigen::Matrix<double, 2, 1> by = Eigen::MatrixXd::Random(2, 1);

    // Stack a_prev (n_a = 5 rows) on top of xt (n_x = 3 rows): shape (n_a + n_x, m).
    Eigen::Matrix<double, 5 + 3, 10>    concat;
    concat << a_prev, xt;
    std::cout << "concat a_prev xt:\n" << concat << std::endl;

    // Forget gate, update gate and candidate memory value.
    Eigen::MatrixXd ft = sigmond_forward(matrix_add_bias(Wf * concat, bf));
    std::cout << "ft:\n" << ft << std::endl;
    Eigen::MatrixXd it = sigmond_forward(matrix_add_bias(Wi * concat, bi));
    std::cout << "it:\n" << it << std::endl;
    Eigen::MatrixXd cct = (matrix_add_bias(Wc * concat, bc)).array().tanh();
    std::cout << "cct:\n" << cct << std::endl;

    // New memory state. cwiseProduct() multiplies Matrix objects element-wise
    // directly, without converting them to Array first.
    Eigen::MatrixXd c_next = it.cwiseProduct(cct) + ft.cwiseProduct(c_prev);
    std::cout << "c_next:\n" << c_next << std::endl;

    // Output gate and new hidden state.
    Eigen::MatrixXd ot = sigmond_forward(matrix_add_bias(Wo * concat, bo));
    std::cout << "ot:\n" << ot << std::endl;

    Eigen::MatrixXd t = (c_next.array().tanh());
    Eigen::MatrixXd a_next = ot.cwiseProduct(t);
    std::cout << "a_next:\n" << a_next << std::endl;

    // Prediction softmax(Wy * a_next + by), shape (n_y, m).
    Eigen::MatrixXd yt_pred = softmax_forward(matrix_add_bias(Wy*a_next, by));
    // Print the prediction itself, not a_next.
    std::cout << "yt_pred:\n" << yt_pred << std::endl;

下面是上述代码中用到的函数的实现

// Add a bias vector to every column of x, like NumPy's `x + b` when b has
// shape (rows, 1).
//   x -- matrix of shape (rows, cols)
//   b -- bias vector of length rows
// Returns a new (rows, cols) matrix with res(i, j) = x(i, j) + b(i).
Eigen::MatrixXd matrix_add_bias(const Eigen::MatrixXd & x, const Eigen::VectorXd& b)
{
    // Every row needs its own bias entry; a shorter b would otherwise be read
    // out of bounds.
    eigen_assert(b.size() == x.rows() && "matrix_add_bias: bias length must equal x.rows()");

    // Column-wise broadcasting adds b to each column in one expression.
    return x.colwise() + b;
}

// Element-wise logistic sigmoid: res(i, j) = 1 / (1 + exp(-x(i, j))).
// Returns a new matrix with the same shape as x. (The "sigmond" spelling is
// kept because callers use this name.)
Eigen::MatrixXd  sigmond_forward(const Eigen::MatrixXd &x)
{
    // Apply the scalar formula to each coefficient; std::exp keeps the
    // per-element arithmetic exactly as in an explicit loop.
    return x.unaryExpr([](double value) {
        return 1.0 / (1.0 + std::exp(-1.0 * value));
    });
}

// Column-wise softmax: each column of x holds the scores of one example and is
// mapped to a probability distribution. Entries lie in (0, 1] and sum to 1
// down each column, like NumPy's `e_x / e_x.sum(axis=0)`.
//   x -- matrix of shape (classes, examples)
// Returns a matrix of the same shape as x.
Eigen::MatrixXd  softmax_forward(const Eigen::MatrixXd &x)
{
    // Nothing to normalise; also avoids calling maxCoeff() on an empty column.
    if (x.size() == 0)
        return x;

    Eigen::MatrixXd res(x.rows(), x.cols());
    for (Eigen::Index j = 0; j < x.cols(); ++j)
    {
        // Softmax is invariant to shifting a column by a constant. Shifting by
        // the column's own maximum guarantees one exp() == 1 in each column.
        // Subtracting the global maximum instead lets a column whose scores
        // are far below it underflow to all zeros, which then becomes 0/0 = NaN.
        const double col_max = x.col(j).maxCoeff();
        res.col(j) = (x.col(j).array() - col_max).exp().matrix();
        res.col(j) /= res.col(j).sum();
    }
    return res;
}

为了验证C++代码的正确性,用Python实现同样的计算并打印全部数据,比较两边的数据是否一致

Python代码主要来自deeplearning.ai的作业代码

import numpy as np

np.random.seed(1)  # fix NumPy's global RNG so every np.random draw below is reproducible
def softmax(x):
    """Column-wise softmax (each column of x is one example).

    Arguments:
    x -- numpy array of scores; softmax is taken along axis 0.

    Returns:
    numpy array of the same shape whose columns are non-negative and sum to 1.
    """
    # Shift every column by its own maximum. Softmax is invariant to a
    # per-column shift, and the column maximum keeps one exp() equal to 1 in
    # every column. Shifting by the global np.max(x) can make a column whose
    # scores are far below that maximum underflow to all zeros, giving 0/0 = NaN.
    e_x = np.exp(x - np.max(x, axis=0, keepdims=True))
    return e_x / e_x.sum(axis=0)

def sigmoid(x):
    """Element-wise logistic sigmoid, 1 / (1 + exp(-x))."""
    exp_neg_x = np.exp(-x)
    return 1 / (1 + exp_neg_x)

def lstm_cell_forward(xt, a_prev, c_prev, parameters):
    """
    Implement a single forward step of the LSTM-cell as described in Figure (4)

    Arguments:
    xt -- your input data at timestep "t", numpy array of shape (n_x, m).
    a_prev -- Hidden state at timestep "t-1", numpy array of shape (n_a, m)
    c_prev -- Memory state at timestep "t-1", numpy array of shape (n_a, m)
    parameters -- python dictionary containing:
                        Wf -- Weight matrix of the forget gate, numpy array of shape (n_a, n_a + n_x)
                        bf -- Bias of the forget gate, numpy array of shape (n_a, 1)
                        Wi -- Weight matrix of the update gate, numpy array of shape (n_a, n_a + n_x)
                        bi -- Bias of the update gate, numpy array of shape (n_a, 1)
                        Wc -- Weight matrix of the first "tanh", numpy array of shape (n_a, n_a + n_x)
                        bc --  Bias of the first "tanh", numpy array of shape (n_a, 1)
                        Wo -- Weight matrix of the output gate, numpy array of shape (n_a, n_a + n_x)
                        bo --  Bias of the output gate, numpy array of shape (n_a, 1)
                        Wy -- Weight matrix relating the hidden-state to the output, numpy array of shape (n_y, n_a)
                        by -- Bias relating the hidden-state to the output, numpy array of shape (n_y, 1)

    Returns:
    a_next -- next hidden state, of shape (n_a, m)
    c_next -- next memory state, of shape (n_a, m)
    yt_pred -- prediction at timestep "t", softmax(Wy a_next + by), numpy array of shape (n_y, m)
    cache -- tuple of values needed for the backward pass, contains
             (a_next, c_next, a_prev, c_prev, ft, it, cct, ot, xt, parameters)

    Note: ft/it/ot stand for the forget/update/output gates, cct stands for the candidate value (c tilde),
          c stands for the memory value
    Every intermediate value is also printed so it can be compared with the C++ version.
    """

    # Retrieve parameters from "parameters"
    Wf = parameters["Wf"]
    bf = parameters["bf"]
    Wi = parameters["Wi"]
    bi = parameters["bi"]
    Wc = parameters["Wc"]
    bc = parameters["bc"]
    Wo = parameters["Wo"]
    bo = parameters["bo"]
    Wy = parameters["Wy"]
    by = parameters["by"]

    # Retrieve dimensions from shapes of xt and Wy
    n_x, m = xt.shape
    n_y, n_a = Wy.shape

    ### START CODE HERE ###
    # Concatenate a_prev and xt: rows [0, n_a) hold a_prev, rows [n_a, n_a + n_x) hold xt.
    concat = np.zeros([n_x+n_a,m])
    concat[: n_a, :] = a_prev
    concat[n_a :, :] = xt

    # Compute values for ft, it, cct, c_next, ot, a_next using the formulas given figure (4) (≈6 lines)
    ft = sigmoid(np.dot(Wf,concat)+bf)
    print("ft:\n",ft)
    it = sigmoid(np.dot(Wi,concat)+bi)
    print("it:\n",it)
    cct = np.tanh(np.dot(Wc,concat)+bc)
    print("cct:\n",cct)
    c_next = ft*c_prev + it*cct
    print("c_next:\n",c_next)
    ot = sigmoid(np.dot(Wo,concat)+bo)
    print("ot:\n",ot)
    a_next =ot* np.tanh(c_next)
    print("a_next:\n",a_next)

    # Compute prediction of the LSTM cell: project the hidden state with Wy/by
    # before the softmax so the result has the documented shape (n_y, m).
    # Applying softmax to a_next alone ignores Wy/by and returns shape (n_a, m).
    yt_pred = softmax(np.dot(Wy, a_next) + by)
    print("yt_pred:\n",yt_pred)
    ### END CODE HERE ###

    # store values needed for backward propagation in cache
    cache = (a_next, c_next, a_prev, c_prev, ft, it, cct, ot, xt, parameters)

    return a_next, c_next, yt_pred, cache



# xt is hard-coded (instead of xt = np.random.randn(3,10)) so the exact same
# numbers can be pasted into the C++ version. These values equal the first
# draws after np.random.seed(1). Because xt no longer consumes random numbers,
# the first three rows of a_prev below repeat them; compare the printed output.
xt = np.array([[ 1.62434536, -0.61175641, -0.52817175, -1.07296862,  0.86540763, -2.3015387,
                 1.74481176, -0.7612069,   0.3190391,  -0.24937038],
               [ 1.46210794, -2.06014071, -0.3224172,  -0.38405435,  1.13376944, -1.09989127,
                -0.17242821, -0.87785842,  0.04221375,  0.58281521],
               [-1.10061918,  1.14472371,  0.90159072,  0.50249434,  0.90085595, -0.68372786,
                -0.12289023, -0.93576943, -0.26788808,  0.53035547]])

print("xt:\n",xt)
a_prev = np.random.randn(5,10)
print("a_prev:\n",a_prev)
c_prev = np.random.randn(5,10)
print("c_prev:\n",c_prev)
Wf = np.random.randn(5, 5+3)
print("Wf:\n",Wf)

bf = np.random.randn(5,1)
print("bf:\n",bf)
Wi = np.random.randn(5, 5+3)
print("Wi:\n",Wi)
bi = np.random.randn(5,1)
print("bi:\n",bi)
Wo = np.random.randn(5, 5+3)
print("Wo:\n",Wo)
bo = np.random.randn(5,1)
print("bo:\n",bo)
Wc = np.random.randn(5, 5+3)
print("Wc:\n",Wc)
bc = np.random.randn(5,1)
print("bc:\n",bc)
Wy = np.random.randn(2,5)
print("Wy:\n",Wy)
by = np.random.randn(2,1)
print("by:\n",by)

parameters = {"Wf": Wf, "Wi": Wi, "Wo": Wo, "Wc": Wc, "Wy": Wy, "bf": bf, "bi": bi, "bo": bo, "bc": bc, "by": by}

a_next, c_next, yt, cache = lstm_cell_forward(xt, a_prev, c_prev, parameters)
print("a_next[4] = ", a_next[4])
# Report a_next's own shape (this previously printed c_next.shape under this label).
print("a_next.shape = ", a_next.shape)
print("c_next[2] = ", c_next[2])
print("c_next.shape = ", c_next.shape)
print("yt[1] =", yt[1])
print("yt.shape = ", yt.shape)
print("cache[1][3] =", cache[1][3])
print("len(cache) = ", len(cache))

输出数据如下

xt:
 [[ 1.62434536 -0.61175641 -0.52817175 -1.07296862  0.86540763 -2.3015387
   1.74481176 -0.7612069   0.3190391  -0.24937038]
 [ 1.46210794 -2.06014071 -0.3224172  -0.38405435  1.13376944 -1.09989127
  -0.17242821 -0.87785842  0.04221375  0.58281521]
 [-1.10061918  1.14472371  0.90159072  0.50249434  0.90085595 -0.68372786
  -0.12289023 -0.93576943 -0.26788808  0.53035547]]
a_prev:
 [[ 1.62434536 -0.61175641 -0.52817175 -1.07296862  0.86540763 -2.3015387
   1.74481176 -0.7612069   0.3190391  -0.24937038]
 [ 1.46210794 -2.06014071 -0.3224172  -0.38405435  1.13376944 -1.09989127
  -0.17242821 -0.87785842  0.04221375  0.58281521]
 [-1.10061918  1.14472371  0.90159072  0.50249434  0.90085595 -0.68372786
  -0.12289023 -0.93576943 -0.26788808  0.53035547]
 [-0.69166075 -0.39675353 -0.6871727  -0.84520564 -0.67124613 -0.0126646
  -1.11731035  0.2344157   1.65980218  0.74204416]
 [-0.19183555 -0.88762896 -0.74715829  1.6924546   0.05080775 -0.63699565
   0.19091548  2.10025514  0.12015895  0.61720311]]
c_prev:
 [[ 0.30017032 -0.35224985 -1.1425182  -0.34934272 -0.20889423  0.58662319
   0.83898341  0.93110208  0.28558733  0.88514116]
 [-0.75439794  1.25286816  0.51292982 -0.29809284  0.48851815 -0.07557171
   1.13162939  1.51981682  2.18557541 -1.39649634]
 [-1.44411381 -0.50446586  0.16003707  0.87616892  0.31563495 -2.02220122
  -0.30620401  0.82797464  0.23009474  0.76201118]
 [-0.22232814 -0.20075807  0.18656139  0.41005165  0.19829972  0.11900865
  -0.67066229  0.37756379  0.12182127  1.12948391]
 [ 1.19891788  0.18515642 -0.37528495 -0.63873041  0.42349435  0.07734007
  -0.34385368  0.04359686 -0.62000084  0.69803203]]
Wf:
 [[-0.44712856  1.2245077   0.40349164  0.59357852 -1.09491185  0.16938243
   0.74055645 -0.9537006 ]
 [-0.26621851  0.03261455 -1.37311732  0.31515939  0.84616065 -0.85951594
   0.35054598 -1.31228341]
 [-0.03869551 -1.61577235  1.12141771  0.40890054 -0.02461696 -0.77516162
   1.27375593  1.96710175]
 [-1.85798186  1.23616403  1.62765075  0.3380117  -1.19926803  0.86334532
  -0.1809203  -0.60392063]
 [-1.23005814  0.5505375   0.79280687 -0.62353073  0.52057634 -1.14434139
   0.80186103  0.0465673 ]]
bf:
 [[-0.18656977]
 [-0.10174587]
 [ 0.86888616]
 [ 0.75041164]
 [ 0.52946532]]
Wi:
 [[ 0.13770121  0.07782113  0.61838026  0.23249456  0.68255141 -0.31011677
  -2.43483776  1.0388246 ]
 [ 2.18697965  0.44136444 -0.10015523 -0.13644474 -0.11905419  0.01740941
  -1.12201873 -0.51709446]
 [-0.99702683  0.24879916 -0.29664115  0.49521132 -0.17470316  0.98633519
   0.2135339   2.19069973]
 [-1.89636092 -0.64691669  0.90148689  2.52832571 -0.24863478  0.04366899
  -0.22631424  1.33145711]
 [-0.28730786  0.68006984 -0.3198016  -1.27255876  0.31354772  0.50318481
   1.29322588 -0.11044703]]
bi:
 [[-0.61736206]
 [ 0.5627611 ]
 [ 0.24073709]
 [ 0.28066508]
 [-0.0731127 ]]
Wo:
 [[ 1.16033857  0.36949272  1.90465871  1.1110567   0.6590498  -1.62743834
   0.60231928  0.4202822 ]
 [ 0.81095167  1.04444209 -0.40087819  0.82400562 -0.56230543  1.95487808
  -1.33195167 -1.76068856]
 [-1.65072127 -0.89055558 -1.1191154   1.9560789  -0.3264995  -1.34267579
   1.11438298 -0.58652394]
 [-1.23685338  0.87583893  0.62336218 -0.43495668  1.40754     0.12910158
   1.6169496   0.50274088]
 [ 1.55880554  0.1094027  -1.2197444   2.44936865 -0.54577417 -0.19883786
  -0.7003985  -0.20339445]]
bo:
 [[ 0.24266944]
 [ 0.20183018]
 [ 0.66102029]
 [ 1.79215821]
 [-0.12046457]]
Wc:
 [[-1.23312074e+00 -1.18231813e+00 -6.65754518e-01 -1.67419581e+00
   8.25029824e-01 -4.98213564e-01 -3.10984978e-01 -1.89148284e-03]
 [-1.39662042e+00 -8.61316361e-01  6.74711526e-01  6.18539131e-01
  -4.43171931e-01  1.81053491e+00 -1.30572692e+00 -3.44987210e-01]
 [-2.30839743e-01 -2.79308500e+00  1.93752881e+00  3.66332015e-01
  -1.04458938e+00  2.05117344e+00  5.85662000e-01  4.29526140e-01]
 [-6.06998398e-01  1.06222724e-01 -1.52568032e+00  7.95026094e-01
  -3.74438319e-01  1.34048197e-01  1.20205486e+00  2.84748111e-01]
 [ 2.62467445e-01  2.76499305e-01 -7.33271604e-01  8.36004719e-01
   1.54335911e+00  7.58805660e-01  8.84908814e-01 -8.77281519e-01]]
bc:
 [[-0.86778722]
 [-1.44087602]
 [ 1.23225307]
 [-0.25417987]
 [ 1.39984394]]
Wy:
 [[-0.78191168 -0.43750898  0.09542509  0.92145007  0.0607502 ]
 [ 0.21112476  0.01652757  0.17718772 -1.11647002  0.0809271 ]]
by:
 [[-0.18657899]
 [-0.05682448]]
ft:
 [[0.93342112 0.01873532 0.31879297 0.03645608 0.70083348 0.34466958
  0.14007722 0.03403372 0.69185905 0.62265514]
 [0.76944697 0.01548017 0.05212295 0.6846647  0.0380791  0.96645775
  0.11998371 0.99169665 0.71083407 0.43393576]
 [0.0096553  0.99579174 0.9807537  0.95426937 0.90749184 0.73438514
  0.20861643 0.2578163  0.60901046 0.9426081 ]
 [0.38869392 0.78367665 0.92568891 0.40404883 0.84818819 0.8742035
  0.13017358 0.05644821 0.65028999 0.84117879]
 [0.12537878 0.48551632 0.8950891  0.98772175 0.77013915 0.97363848
  0.04094597 0.78774353 0.20290836 0.90145966]]
it:
 [[1.56339072e-03 9.96099904e-01 7.42313578e-01 9.05874735e-01
  1.12309515e-01 6.90213408e-01 3.00505508e-01 8.20579645e-01
  3.21278251e-01 3.83345996e-01]
 [9.81012966e-01 5.17414473e-01 3.19542854e-01 1.25933181e-01
  7.73509903e-01 3.68808714e-02 9.91267605e-01 4.44691623e-01
  7.61676248e-01 2.91990904e-01]
 [1.83278673e-01 8.05585226e-01 8.31349370e-01 5.77414603e-01
  8.92894351e-01 1.92603857e-01 3.36918546e-01 1.01540368e-01
  6.34230205e-01 8.55328961e-01]
 [2.84591243e-04 9.93218262e-01 8.81040869e-01 7.62818965e-01
  1.18062012e-01 9.83774288e-01 2.60383249e-03 6.07897288e-01
  9.61623854e-01 9.58539298e-01]
 [9.88527761e-01 1.05985672e-02 3.60990517e-01 5.81138043e-01
  9.44460008e-01 6.72442596e-02 8.17257718e-01 2.30253916e-01
  1.32365782e-01 5.11039172e-01]]
cct:
 [[-0.99948169  0.98478327  0.43025665  0.99937994 -0.99815922  0.99983835
  -0.90887136  0.99884035 -0.99930754 -0.98346529]
 [-0.99981448  0.9972597  -0.63957201 -0.97374353 -0.99872911  0.03899929
  -0.8214443  -0.73883911 -0.47400845 -0.9851213 ]
 [-0.93498283  0.99999984  0.99863728 -0.64181366  0.97259978 -0.90278087
   0.99916331 -0.98760363  0.91674587  0.35773607]
 [ 0.94449743 -0.99940948 -0.9479902  -0.97468636 -0.69147317  0.44003728
  -0.97115684 -0.44742202  0.85028554  0.31577646]
 [ 0.99997551 -0.99993444 -0.99089385  0.7416703   0.93112785 -0.9716276
   0.98763731  0.99977554  0.99895712  0.98765691]]
c_next:
 [[ 0.27862274  0.974343   -0.04484141  0.89257737 -0.25850285  0.892293
  -0.15559838  0.85131693 -0.1234696   0.17413022]
 [-1.56130017  0.53539121 -0.17763525 -0.32672026 -0.75392452 -0.07159854
  -0.67849403  1.17864168  1.19254048 -0.89363616]
 [-0.18530577  0.30324215  0.98717342  0.46550858  1.15486499 -1.6589536
   0.27275746  0.11318372  0.72155802  1.02425993]
 [-0.0861488  -1.14996116 -0.6625203  -0.57782835  0.08655877  0.53693514
  -0.08983124 -0.25067383  0.89687402  1.25278205]
 [ 1.13882241  0.07929859 -0.69361675 -0.19987509  1.2055626   0.00996489
   0.79307479  0.26454537  0.00642439  1.13397909]]
ot:
 [[0.07249733 0.54040145 0.73424974 0.84751397 0.91071347 0.14496133
  0.10515789 0.31305591 0.80770578 0.96738433]
 [0.99795666 0.03916275 0.0369245  0.00454259 0.43548428 0.01760292
  0.98680806 0.35063646 0.95032392 0.17690516]
 [0.03601709 0.39948237 0.38506047 0.67347204 0.01054133 0.99982554
  0.00130693 0.98388906 0.96706885 0.86799202]
 [0.91900679 0.07919111 0.86248002 0.99522256 0.99355306 0.48471955
  0.51145512 0.90443733 0.66585323 0.99066356]
 [0.7687726  0.13571441 0.03891578 0.00629472 0.07126004 0.21235694
  0.42270364 0.53065266 0.99071107 0.48047872]]
a_next:
 [[ 1.96924439e-02  4.05628887e-01 -3.29027460e-02  6.03993044e-01
  -2.30314578e-01  1.03288489e-01 -1.62316145e-02  2.16558564e-01
  -9.92234125e-02  1.66768691e-01]
 [-9.13759893e-01  1.91698281e-02 -6.49096391e-03 -1.43350878e-03
  -2.77614447e-01 -1.25819395e-03 -5.82749126e-01  2.89984356e-01
   7.90066150e-01 -1.26166120e-01]
 [-6.59881685e-03  1.17558399e-01  2.91165199e-01  2.92666167e-01
   8.63713827e-03 -9.29913841e-01  3.47891416e-04  1.10887123e-01
   5.97526168e-01  6.69739123e-01]
 [-7.89760541e-02 -6.47578349e-02 -5.00271371e-01 -5.18595781e-01
   8.57865859e-02  2.37834666e-01 -4.58214583e-02 -2.22086367e-01
   4.75933484e-01  8.41134712e-01]
 [ 6.25794109e-01  1.07394604e-02 -2.33611599e-02 -1.24166589e-03
   5.95266100e-02  2.11604307e-03  2.79046697e-01  1.37196019e-01
   6.36462524e-03  3.90329719e-01]]
yt_pred:
 [[0.19582658 0.2682632  0.19818694 0.31749071 0.16858565 0.2323208
  0.20414163 0.22004507 0.12005679 0.15121714]
 [0.07699782 0.18227372 0.20349115 0.17329962 0.16079722 0.20925895
  0.11585003 0.23681001 0.29214623 0.11281869]
 [0.19074515 0.20111926 0.27404125 0.23255382 0.21408982 0.08267496
  0.20755441 0.19797899 0.24098013 0.25005659]
 [0.17742731 0.16760028 0.12419366 0.10332297 0.23126058 0.26577905
  0.19818961 0.1419092  0.21339007 0.29680722]
 [0.35900314 0.18074354 0.20008701 0.17333287 0.22526672 0.20996623
  0.27426432 0.20325672 0.13342679 0.18910036]]

再把这些数据填到C++代码中,然后查看输出

// Inputs and parameters copied verbatim from the Python output above, so the
// C++ forward pass can be compared number by number (n_x = 3, n_a = 5,
// n_y = 2, m = 10).
    Eigen::Matrix<double, 3, 10> xt;
    xt << 1.62434536, -0.61175641, -0.52817175, -1.07296862, 0.86540763, -2.3015387, 1.74481176, -0.7612069, 0.3190391, -0.24937038,
        1.46210794, -2.06014071, -0.3224172, -0.38405435, 1.13376944, -1.09989127, -0.17242821, -0.87785842, 0.04221375, 0.58281521,
        -1.10061918, 1.14472371, 0.90159072, 0.50249434, 0.90085595, -0.68372786, -0.12289023, -0.93576943, -0.26788808, 0.53035547;

    Eigen::Matrix<double, 5, 10>  a_prev;
    a_prev << 1.62434536, -0.61175641, -0.52817175, -1.07296862, 0.86540763, -2.3015387, 1.74481176, -0.7612069, 0.3190391, -0.24937038,
        1.46210794, -2.06014071, -0.3224172, -0.38405435, 1.13376944, -1.09989127, -0.17242821, -0.87785842, 0.04221375, 0.58281521,
        -1.10061918, 1.14472371, 0.90159072, 0.50249434, 0.90085595, -0.68372786, -0.12289023, -0.93576943, -0.26788808, 0.53035547,
        -0.69166075, -0.39675353, -0.6871727, -0.84520564, -0.67124613, -0.0126646, -1.11731035, 0.2344157, 1.65980218, 0.74204416,
        -0.19183555, -0.88762896, -0.74715829, 1.6924546, 0.05080775, -0.63699565, 0.19091548, 2.10025514, 0.12015895, 0.61720311;

    Eigen::Matrix<double, 5, 10>  c_prev;
    c_prev << 0.30017032, -0.35224985, -1.1425182, -0.34934272, -0.20889423, 0.58662319, 0.83898341, 0.93110208, 0.28558733, 0.88514116,
        -0.75439794, 1.25286816, 0.51292982, -0.29809284, 0.48851815, -0.07557171, 1.13162939, 1.51981682, 2.18557541, -1.39649634,
        -1.44411381, -0.50446586, 0.16003707, 0.87616892, 0.31563495, -2.02220122, -0.30620401, 0.82797464, 0.23009474, 0.76201118,
        -0.22232814, -0.20075807, 0.18656139, 0.41005165, 0.19829972, 0.11900865, -0.67066229, 0.37756379, 0.12182127, 1.12948391,
        1.19891788, 0.18515642, -0.37528495, -0.63873041, 0.42349435, 0.07734007, -0.34385368, 0.04359686, -0.62000084, 0.69803203;


    Eigen::Matrix<double, 5, 5 + 3> Wf;
    Wf << -0.44712856, 1.2245077 , 0.40349164, 0.59357852, -1.09491185, 0.16938243, 0.74055645, -0.9537006,
        -0.26621851, 0.03261455, -1.37311732, 0.31515939, 0.84616065, -0.85951594, 0.35054598, -1.31228341,
        -0.03869551, -1.61577235, 1.12141771, 0.40890054, -0.02461696, -0.77516162, 1.27375593, 1.96710175,
        -1.85798186, 1.23616403, 1.62765075, 0.3380117, -1.19926803, 0.86334532, -0.1809203, -0.60392063,
        -1.23005814, 0.5505375 , 0.79280687, -0.62353073, 0.52057634, -1.14434139, 0.80186103, 0.0465673;

    Eigen::VectorXd bf(5); // equivalent to Eigen::Matrix<double, 5, 1> bf;
    bf << -0.18656977,
        -0.10174587,
        0.86888616,
        0.75041164,
        0.52946532;

    Eigen::Matrix<double, 5, 5 + 3> Wi;
    Wi << 0.13770121, 0.07782113, 0.61838026, 0.23249456, 0.68255141, -0.31011677, -2.43483776, 1.0388246,
        2.18697965, 0.44136444, -0.10015523, -0.13644474, -0.11905419, 0.01740941, -1.12201873, -0.51709446,
        -0.99702683, 0.24879916, -0.29664115, 0.49521132, -0.17470316, 0.98633519, 0.2135339,  2.19069973,
        -1.89636092, -0.64691669, 0.90148689, 2.52832571, -0.24863478, 0.04366899, -0.22631424, 1.33145711,
        -0.28730786, 0.68006984, -0.3198016, -1.27255876, 0.31354772, 0.50318481, 1.29322588, -0.11044703;

    Eigen::Matrix<double, 5, 1> bi;
    bi << -0.61736206,
        0.5627611,
        0.24073709,
        0.28066508,
        -0.0731127;


    Eigen::Matrix<double, 5, 5 + 3> Wo;
    Wo << 1.16033857, 0.36949272, 1.90465871, 1.1110567 , 0.6590498, -1.62743834, 0.60231928, 0.4202822,
        0.81095167, 1.04444209, -0.40087819, 0.82400562, -0.56230543, 1.95487808, -1.33195167, -1.76068856,
        -1.65072127, -0.89055558, -1.1191154,  1.9560789, -0.3264995, -1.34267579, 1.11438298, -0.58652394,
        -1.23685338, 0.87583893, 0.62336218, -0.43495668, 1.40754 , 0.12910158, 1.6169496 , 0.50274088,
        1.55880554, 0.1094027, -1.2197444 , 2.44936865, -0.54577417, -0.19883786, -0.7003985, -0.20339445;

    Eigen::Matrix<double, 5, 1> bo;
    bo << 0.24266944,
        0.20183018,
        0.66102029,
        1.79215821,
        -0.12046457;

    Eigen::Matrix<double, 5, 5 + 3> Wc;
    Wc << -1.23312074e+00, -1.18231813e+00, -6.65754518e-01, -1.67419581e+00, 8.25029824e-01, -4.98213564e-01, -3.10984978e-01, -1.89148284e-03,
        -1.39662042e+00, -8.61316361e-01, 6.74711526e-01, 6.18539131e-01, -4.43171931e-01, 1.81053491e+00, -1.30572692e+00, -3.44987210e-01,
        -2.30839743e-01, -2.79308500e+00, 1.93752881e+00, 3.66332015e-01, -1.04458938e+00, 2.05117344e+00, 5.85662000e-01, 4.29526140e-01,
        -6.06998398e-01, 1.06222724e-01, -1.52568032e+00, 7.95026094e-01, -3.74438319e-01, 1.34048197e-01, 1.20205486e+00, 2.84748111e-01,
        2.62467445e-01, 2.76499305e-01, -7.33271604e-01, 8.36004719e-01, 1.54335911e+00, 7.58805660e-01, 8.84908814e-01, -8.77281519e-01;


    Eigen::Matrix<double, 5, 1> bc;
    bc << -0.86778722,
        -1.44087602,
        1.23225307,
        -0.25417987,
        1.39984394;

    Eigen::Matrix<double, 2, 5> Wy;
    Wy << -0.78191168, -0.43750898, 0.09542509, 0.92145007, 0.0607502,
        0.21112476, 0.01652757, 0.17718772, -1.11647002, 0.0809271;


    // Filled directly by the comma initializer; no random pre-fill is needed.
    Eigen::Matrix<double, 2, 1> by;
    by << -0.18657899,
        -0.05682448;


    // Stack a_prev (n_a = 5 rows) on top of xt (n_x = 3 rows): shape (n_a + n_x, m).
    Eigen::Matrix<double, 5 + 3, 10>    concat;
    concat << a_prev, xt;
    std::cout << "concat a_prev xt:\n" << concat << std::endl;

    // Forget gate, update gate and candidate memory value.
    Eigen::MatrixXd ft = sigmond_forward(matrix_add_bias(Wf * concat, bf));
    std::cout << "ft:\n" << ft << std::endl;
    Eigen::MatrixXd it = sigmond_forward(matrix_add_bias(Wi * concat, bi));
    std::cout << "it:\n" << it << std::endl;
    Eigen::MatrixXd cct = (matrix_add_bias(Wc * concat, bc)).array().tanh();
    std::cout << "cct:\n" << cct << std::endl;

    // New memory state (element-wise products).
    Eigen::MatrixXd c_next = it.cwiseProduct(cct) + ft.cwiseProduct(c_prev);
    std::cout << "c_next:\n" << c_next << std::endl;

    // Output gate and new hidden state.
    Eigen::MatrixXd ot = sigmond_forward(matrix_add_bias(Wo * concat, bo));
    std::cout << "ot:\n" << ot << std::endl;

    Eigen::MatrixXd t = (c_next.array().tanh());
    Eigen::MatrixXd a_next = ot.cwiseProduct(t);
    std::cout << "a_next:\n" << a_next << std::endl;

    // Prediction softmax(Wy * a_next + by), shape (n_y, m) = (2, 10).
    // softmax_forward(a_next) alone would ignore Wy/by and give shape (n_a, m).
    Eigen::MatrixXd yt_pred = softmax_forward(matrix_add_bias(Wy * a_next, by));
    std::cout << "yt_pred:\n" << yt_pred << std::endl;

结果表明C++代码与Python代码的输出是一致的

猜你喜欢

转载自blog.csdn.net/flyfish1986/article/details/80142437