class LabelSmoothingLoss(nn.Module):
    """Cross-entropy loss with label smoothing.

    The target distribution places ``1 - smoothing`` on the true class and
    spreads ``smoothing`` uniformly over the remaining ``num_class - 1``
    classes. The loss is the cross-entropy between that distribution and
    ``log_softmax(pred)``. It is averaged over every position except the
    class dimension.

    Args:
        smoothing: Probability mass moved away from the true class.
        dim: Dimension of ``pred`` that indexes classes. ``target`` must have
            the shape of ``pred`` with this dimension removed.
    """

    def __init__(self, smoothing=0.05, dim=-1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.dim = dim

    def forward(self, pred, target):
        """Return the mean label-smoothed cross-entropy as a scalar tensor.

        Args:
            pred: Unnormalized logits with classes along ``self.dim``.
            target: Integer (int64) class indices; same shape as ``pred``
                with ``self.dim`` removed.
        """
        log_probs = pred.log_softmax(dim=self.dim)
        # Read the class count from the same axis the softmax/sum use,
        # not a hard-coded one, so any class dimension works.
        num_class = log_probs.size(self.dim)
        # With a single class there is nowhere to move the smoothing mass;
        # avoid dividing by zero (log_softmax is 0 there, so the loss is 0).
        off_value = self.smoothing / (num_class - 1) if num_class > 1 else 0.0
        with torch.no_grad():
            true_dist = torch.full_like(log_probs, off_value)
            # Insert the index axis at self.dim so target lines up with pred.
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=self.dim))
PyTorch 实现标签平滑
猜你喜欢
转载自blog.csdn.net/qq_55542491/article/details/130882950
今日推荐
周排行