版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/diyoosjtu/article/details/89141554
CrossEntropyLoss() 函数联合调用了 nn.LogSoftmax() 和 nn.NLLLoss()。
假设网络得到的输出为
h,它的维度大小为
B×C,其中
B 是 batch_size,
C 是分类的总数目。与之对应的训练数据的标签
y 维度是
1×B,
y 中元素的取值范围是
[0,C−1],即
0≤y[j]≤C−1j=0,1,⋯,B−1
我们将CrossEntropyLoss() 函数的计算过程拆解为如下两个步骤:
- 对输出
h,执行LogSoftmax(dim=1),得到
s,维度仍然是
B×C。
- 对
s 执行
−log()操作,得到负对数概率
p,维度仍然是
B×C。
则交叉熵的计算公式为:
L=B1i=0∑B{−log(p[i,y[i]])}(1)
式(1)其实是从式(2)化简得来的:
L=B1i=0∑B{−j=0∑C−1y[i,j]log(p[i,j])}(2)
举例说明:
对于
C=10,
y=[7,7,2,4] 的情况,可知
B=4,首先需要把
y扩展为
B×C 的矩阵:
y=⎣⎢⎢⎡0000000000100000000100000000110000000000⎦⎥⎥⎤
其中为1的元素位置,就是最终概率
p 中需要取值的位置。
网络得到的输出
h=⎣⎢⎢⎡−0.1070−0.0977−0.1049−0.11640.0083−0.0053−0.0091−0.0018−0.0789−0.0613−0.0663−0.07460.03410.05760.06110.05310.06860.06900.07090.0670−0.0088−0.0104−0.0168−0.01420.05400.05580.06020.0700−0.1017−0.1133−0.1072−0.10050.02670.05020.04770.04910.09250.07750.08780.0939⎦⎥⎥⎤
则
s=⎣⎢⎢⎡0.08980.09030.08960.08860.10070.09900.09860.09930.09230.09360.09310.09230.10340.10550.10580.10490.10700.10670.10680.10640.09900.09850.09790.09810.10540.10530.10570.10670.09020.08890.08940.09000.10260.10470.10440.10450.10960.10760.10870.1093⎦⎥⎥⎤
p=⎣⎢⎢⎡2.41072.40482.41232.42422.29542.31242.31652.30962.38262.36842.37372.38242.26962.24952.24632.25472.23512.23812.23652.24082.31252.31752.32422.32202.24972.25132.24722.23782.40542.42042.41462.40832.27702.25692.25972.25872.21122.22962.21962.2139⎦⎥⎥⎤
因此,最终的交叉熵
L=42.4054+2.4204+2.3737+2.2408=2.36