softmax
- 每一行先减去该行的最大值（保证数值稳定，避免 exp 溢出）
- 对每个元素计算 exponential
- 按行求和
- 每一行的元素都除以该行的和
# Build a 10x10 matrix of large values: uniform samples from [0, 1) scaled
# by 10 and offset by 1000. Values this large overflow a direct np.exp().
m = 1000 + 10 * np.random.rand(10, 10)
print(m)
[[ 1008.64304012 1001.25079229 1006.81896868 1005.89015258 1008.8915297
1001.84923866 1005.53509734 1005.34075305 1008.93404709
1006.94897664]
[ 1003.24267825 1003.72710741 1000.28354398 1000.32012105 1004.3690361
1007.18390602 1002.49741606 1005.83510332 1009.19678396
1002.32098566]
[ 1002.32824002 1006.2813999 1009.27645662 1002.57259159
1006.30743627 1000.35201323 1003.94430099 1008.79056869
1007.40485841 1006.38239542]
[ 1007.06228714 1006.01325352 1007.96901864 1002.34269542
1000.75563221 1005.26357317 1006.14861174 1005.68119044
1000.69006453 1007.21834125]
[ 1004.15770428 1003.0554848 1005.55619032 1003.04000025
1005.54338468 1002.23952638 1008.86317857 1006.96983789
1005.84232318 1009.28833837]
[ 1008.47151667 1006.30354927 1006.69274016 1004.12418543
1007.17550972 1004.31758292 1007.27760499 1007.45250445
1000.02943239 1002.25886446]
[ 1000.63764781 1003.39894276 1008.26298759 1001.89295012
1007.85388369 1004.67565255 1004.58872708 1003.24488815
1000.39528914 1007.20964465]
[ 1005.21815308 1007.42651355 1006.32407717 1003.0096329 1005.03545902
1008.85925437 1009.57634418 1003.74546024 1003.40512867 1004.4437606 ]
[ 1001.78786625 1008.73282377 1003.98906267 1008.17533941
1002.79957584 1000.89332666 1007.64343999 1003.88248211
1005.75517566 1008.27556001]
[ 1002.05916059 1007.25663392 1009.48655775 1009.56831564
1008.28488062 1004.92593854 1008.0468565 1007.53278621
1001.94935121 1007.01473574]]
# Per-row maximum, kept as a column vector so it broadcasts against m.
# keepdims=True retains the reduced axis with length 1, so the result is
# (rows, 1) for any number of rows instead of a hard-coded reshape(10, 1).
m_row_max = m.max(axis=1, keepdims=True)
print(m_row_max, m_row_max.shape)
[[ 1008.93404709]
[ 1009.19678396]
[ 1009.27645662]
[ 1007.96901864]
[ 1009.28833837]
[ 1008.47151667]
[ 1008.26298759]
[ 1009.57634418]
[ 1008.73282377]
[ 1009.56831564]] (10, 1)
# Shift every row so that its largest entry becomes 0; the column of row
# maxima broadcasts across the columns of m. A new array is bound to m.
m = np.subtract(m, m_row_max)
print(m)
[[-0.29100696 -7.6832548 -2.11507841 -3.04389451 -0.04251738 -7.08480843
-3.39894975 -3.59329403 0. -1.98507045]
[-5.95410571 -5.46967655 -8.91323998 -8.87666291 -4.82774786 -2.01287794
-6.6993679 -3.36168065 0. -6.8757983 ]
[-6.9482166 -2.99505673 0. -6.70386503 -2.96902035 -8.9244434
-5.33215563 -0.48588793 -1.87159821 -2.8940612 ]
[-0.9067315 -1.95576512 0. -5.62632322 -7.21338643 -2.70544547
-1.82040689 -2.2878282 -7.27895411 -0.75067739]
[-5.13063409 -6.23285357 -3.73214805 -6.24833812 -3.74495369 -7.04881199
-0.4251598 -2.31850048 -3.44601518 0. ]
[ 0. -2.16796739 -1.77877651 -4.34733123 -1.29600694 -4.15393374
-1.19391168 -1.01901222 -8.44208427 -6.21265221]
[-7.62533977 -4.86404483 0. -6.37003747 -0.4091039 -3.58733504
-3.67426051 -5.01809944 -7.86769844 -1.05334294]
[-4.3581911 -2.14983063 -3.25226701 -6.56671127 -4.54088515 -0.71708981
0. -5.83088393 -6.1712155 -5.13258358]
[-6.94495753 0. -4.7437611 -0.55748437 -5.93324793 -7.83949711
-1.08938379 -4.85034166 -2.97764811 -0.45726376]
[-7.50915505 -2.31168172 -0.08175789 0. -1.28343501 -4.6423771
-1.52145914 -2.03552943 -7.61896443 -2.5535799 ]]
# Element-wise exponential of the matrix, printed together with its shape.
m_exp = np.exp(m)
print(m_exp, np.shape(m_exp))
[[ 7.47510474e-01 4.60473707e-04 1.20623832e-01 4.76489585e-02
9.58373807e-01 8.37735258e-04 3.34083387e-02 2.75075701e-02
1.00000000e+00 1.37370936e-01]
[ 2.59516363e-03 4.21259451e-03 1.34595041e-04 1.39609277e-04
8.00452828e-03 1.33603617e-01 1.23169021e-03 3.46769303e-02
1.00000000e+00 1.03247308e-03]
[ 9.60346309e-04 5.00337887e-02 1.00000000e+00 1.22616357e-03
5.13535942e-02 1.33095533e-04 4.83363924e-03 6.15150742e-01
1.53877536e-01 5.53509641e-02]
[ 4.03842028e-01 1.41456204e-01 1.00000000e+00 3.60179403e-03
7.36658284e-04 6.68405420e-02 1.61959837e-01 1.01486632e-01
6.89906753e-04 4.72046684e-01]
[ 5.91281003e-03 1.96383995e-03 2.39413532e-02 1.93366498e-03
2.36367236e-02 8.68440058e-04 6.53665320e-01 9.84210597e-02
3.18723893e-02 1.00000000e+00]
[ 1.00000000e+00 1.14409931e-01 1.68844601e-01 1.29413038e-02
2.73622203e-01 1.57025251e-02 3.03033573e-01 3.60951306e-01
2.15600312e-04 2.00391561e-03]
[ 4.87929429e-04 7.71919784e-03 1.00000000e+00 1.71209508e-03
6.64245216e-01 2.76719769e-02 2.53681582e-02 6.61709096e-03
3.82914649e-04 3.48769883e-01]
[ 1.28015234e-02 1.16503889e-01 3.86864061e-02 1.40641513e-03
1.06639631e-02 4.88170860e-01 1.00000000e+00 2.93548106e-03
2.08869564e-03 5.90129429e-03]
[ 9.63481254e-04 1.00000000e+00 8.70584098e-03 5.72647825e-01
2.64986143e-03 3.93867063e-04 3.36423738e-01 7.82570334e-03
5.09124334e-02 6.33013352e-01]
[ 5.48043962e-04 9.90944624e-02 9.21495038e-01 1.00000000e+00
2.77083877e-01 9.63476759e-03 2.18392989e-01 1.30611314e-01
4.91050084e-04 7.78026413e-02]] (10, 10)
# Per-row sum of the exponentials, kept as a column vector so it can divide
# m_exp row by row. keepdims=True retains the reduced axis with length 1,
# so the result is (rows, 1) without a hard-coded reshape(10, 1).
m_exp_row_sum = m_exp.sum(axis=1, keepdims=True)
print(m_exp_row_sum, m_exp_row_sum.shape)
[[ 3.07374213]
[ 1.1856312 ]
[ 1.93291987]
[ 2.35266029]
[ 1.8422156 ]
[ 2.25172496]
[ 2.08297446]
[ 1.67915853]
[ 2.6135361 ]
[ 2.73515418]] (10, 1)
# Normalise each row of exponentials by that row's sum; the column of row
# sums broadcasts across the columns of m_exp.
m_softmax = np.divide(m_exp, m_exp_row_sum)
print(m_softmax)
[[ 2.43192319e-01 1.49808829e-04 3.92433154e-02 1.55019376e-02
3.11793823e-01 2.72545719e-04 1.08689465e-02 8.94921207e-03
3.25336336e-01 4.46917571e-02]
[ 2.18884559e-03 3.55303952e-03 1.13521845e-04 1.17751015e-04
6.75128005e-03 1.12685645e-01 1.03884767e-03 2.92476533e-02
8.43432594e-01 8.70821452e-04]
[ 4.96837103e-04 2.58850817e-02 5.17352020e-01 6.34358200e-04
2.65678857e-02 6.88572427e-05 2.50069303e-03 3.18249479e-01
7.96088542e-02 2.86359331e-02]
[ 1.71653353e-01 6.01260646e-02 4.25050742e-01 1.53094522e-03
3.13117150e-04 2.84106220e-02 6.88411489e-02 4.31369681e-02
2.93245377e-04 2.00643793e-01]
[ 3.20961891e-03 1.06602069e-03 1.29959562e-02 1.04964098e-03
1.28305957e-02 4.71410652e-04 3.54825635e-01 5.34253752e-02
1.73011179e-02 5.42824629e-01]
[ 4.44103973e-01 5.08099049e-02 7.49845582e-02 5.74728445e-03
1.21516707e-01 6.97355379e-03 1.34578414e-01 1.60299909e-01
9.57489550e-05 8.89946885e-04]
[ 2.34246477e-04 3.70585333e-03 4.80082698e-01 8.21947224e-04
3.18892636e-01 1.32848374e-02 1.21788138e-02 3.17675088e-03
1.83830698e-04 1.67438386e-01]
[ 7.62377296e-03 6.93823048e-02 2.30391624e-02 8.37571382e-04
6.35077805e-03 2.90723509e-01 5.95536385e-01 1.74818578e-03
1.24389425e-03 3.51443547e-03]
[ 3.68650448e-04 3.82623373e-01 3.33105824e-03 2.19108442e-01
1.01389892e-03 1.50702744e-04 1.28723586e-01 2.99429701e-03
1.94802870e-02 2.42205704e-01]
[ 2.00370409e-04 3.62299365e-02 3.36907895e-01 3.65610102e-01
1.01304664e-01 3.52256836e-03 7.98466830e-02 4.77528160e-02
1.79532871e-04 2.84454316e-02]]
# Sanity check: print the sum of each row of the softmax output.
print(np.sum(m_softmax, axis=1))
[ 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]