理解Pytorch里面nn.CrossEntropyLoss的含义

理解Pytorch里面nn.CrossEntropyLoss的含义

  • 先说nn.CrossEntropyLoss的参数,如果神经网络的输出output是一个(batch_size, num_class, h, w)的tensor(其中,num_class代表分类问题的类别数,h为图像高度,w为图像宽度),则nn.CrossEntropyLoss需要的label形状为(batch_size, h, w),对每一个batch而言,label中的数据代表每一个像素所属的类别,如果是一个二分类问题,则label中的数值只能是0或者1,如果是三分类问题,则label中的数值可以是0,1,2,以此类推。
  • 交叉熵就是用来衡量两个分布之间的相似性的,因此也可以判断神经网络实际的输出与期望的输出的接近程度。假设有两个分布 p ( x ) , q ( x ) p(x),q(x) ,则两者的交叉熵为
    C E H = x χ p ( x ) l o g ( q ( x ) ) CEH = -\sum_{x\in\chi}p(x)log(q(x))
  • 在分类问题中,给定label和样本之后,该样本只能属于一个种类,假设该样本属于种类 k k ,则 p ( x = k ) = 1 , p ( x k ) = 0 p(x=k)=1, p(x\neq k)=0 ,因此该样本的输出和label的交叉熵可以化简为
    C E H = l o g ( q ( x = k ) ) CEH = -log(q(x=k))
  • 神经网络的输出一般为与类别数相等的向量,为了将向量转换为概率分布,即 q ( x = k ) q(x=k) 的形式,必须使用softmax函数对神经网络的输出进行转换,而加入softmax函数后的交叉熵函数形式如下,该式即为nn.CrossEntropyLoss的公式
    l o s s ( x , k ) = l o g ( e x p ( x [ k ] ) j e x p ( x [ j ] ) ) loss(x,k)=-log\left(\dfrac{exp(x[k])}{\sum_{j}exp(x[j])}\right)
import torch
import numpy as np
import torch.nn as nn
import math

a = torch.randn((4,3,8,8 ))
b = np.random.randint(0,3,(4,8,8))
b = torch.from_numpy(b)
loss_fn = nn.CrossEntropyLoss()
b = b.long()
loss = loss_fn(a, b)
loss
# tensor(1.3822)
#验证softmax2d就是对每一个N维度沿着C维度做softmax
m = nn.Softmax2d()
output = m(a)
#验证softmax2d就是对每一个N维度沿着C维度做softmax
a01 = math.exp(a[0,0,0,0])
a02 = math.exp(a[0,1,0,0])
aa = a01 + a02
print(a01/aa)
print(a02/aa)
print(output[0,0,0,0])
print(output[0,1,0,0])
loss = 0
for batch in range(4):
    for i in range(8):
        for j in range(8):
            if b[batch, i, j] == 1:
                loss = loss - math.log(output[batch, 1, i, j])
            if b[batch, i, j] == 0:
                loss = loss - math.log(output[batch, 0, i, j])
            if b[batch, i, j] == 2:
                loss = loss - math.log(output[batch, 2, i, j])
print(loss/64/4)  #将总的loss对总样本数取平均值,样本数为图像中像素数量8*8再*batch_size即为8*8*4
# 1.3822217100148755
  • 上述结果能够看出,手动计算的loss等于loss_fn计算得到的loss
发布了1 篇原创文章 · 获赞 0 · 访问量 4079

猜你喜欢

转载自blog.csdn.net/lang_yubo/article/details/105108174