1. The weight parameter of CrossEntropyLoss
When using torch.nn.CrossEntropyLoss, there is a weight parameter. The official documentation describes it as "a manual rescaling weight given to each class", i.e. a per-class scaling factor. However, I found that when the reduction parameter is 'mean' (the default), the result is not what you might expect.
import torch
import torch.nn as nn
pred = torch.tensor([[1, 5], [2, 2]]).float()
label = torch.tensor([0, 1]).long()
c = nn.CrossEntropyLoss(reduction='none')  # per-sample losses
print(c(pred, label))
c = nn.CrossEntropyLoss()  # default reduction='mean'
print(c(pred, label))
c = nn.CrossEntropyLoss(weight=torch.tensor([2.0, 2.0]))  # weighted, reduction='mean'
print(c(pred, label))
c = nn.CrossEntropyLoss(weight=torch.tensor([2.0, 2.0]), reduction='none')  # weighted per-sample losses
print(c(pred, label))
c = nn.CrossEntropyLoss(weight=torch.tensor([2.0, 2.0]), reduction='none')
print(c(pred, label).mean())  # plain mean of the weighted per-sample losses
The output is:
tensor([4.0181, 0.6931])
tensor(2.3556)
tensor(2.3556)
tensor([8.0363, 1.3863])
tensor(4.7113)
As shown above, when weight is set and reduction='mean', it is not that the weight has no effect; rather, the loss is not simply the unweighted mean scaled by the weight. With reduction='mean', PyTorch divides the weighted sum of the per-sample losses by the sum of the target classes' weights, i.e. loss = sum_i(w[y_i] * l_i) / sum_i(w[y_i]). With the uniform weight [2.0, 2.0] the factor cancels, which is why the third output (2.3556) equals the unweighted mean. If you want a pure scaling followed by an ordinary mean over N samples, use reduction='none' and call .mean() yourself, as in the last output (4.7113).
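The normalization is easier to see with non-uniform weights, where the two conventions no longer coincide. Below is a small sketch (the class weights [1.0, 3.0] are my own illustrative choice, not from the original post) checking that reduction='mean' divides by the sum of the selected weights rather than by the batch size:

import torch
import torch.nn as nn

pred = torch.tensor([[1, 5], [2, 2]]).float()
label = torch.tensor([0, 1]).long()
w = torch.tensor([1.0, 3.0])  # hypothetical non-uniform class weights

# Built-in weighted mean: divides by sum(w[label]), not by N
loss_mean = nn.CrossEntropyLoss(weight=w)(pred, label)

# Manual check from the per-sample losses (already multiplied by w[label])
per_sample = nn.CrossEntropyLoss(weight=w, reduction='none')(pred, label)
manual = per_sample.sum() / w[label].sum()

print(loss_mean, manual)        # both ≈ 1.5244
print(per_sample.mean())        # ≈ 3.0488, the "scale then mean" version

So with unequal weights, loss_mean and per_sample.mean() genuinely differ, confirming the behavior described above.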