loss模块实现
前言
前面两篇博客我们已经把数据的导入以及网络的主体实现了,那么这篇博客主要是为了实现网络的损失。
还是上面这张图,我们可以看到,网络主要使用了两个损失函数,即softmax loss以及triplet loss。那么这篇博客主要包括以下几个部分:
- softmax cross entropy loss
- local distance
- Triplet loss
- deepsupervision
1.global feature的softmax loss
这里的softmax loss其实是指把global feature当作分类问题来计算损失:首先使用softmax计算各类别的概率,然后使用交叉熵损失衡量预测概率与真实标签之间的差距。
在utils包下新建losses.py用来存放要用到的损失函数,代码如下:
class CrossEntropyLoss(nn.Module):
    """Softmax cross-entropy loss over classification logits.

    Thin wrapper around ``nn.CrossEntropyLoss`` that places the labels on
    the same device as the logits before computing the loss.

    Args:
        use_gpu (bool): stored on the instance and kept for backward
            compatibility. The labels always follow ``inputs.device``, so
            the module also works on CPU-only hosts.
    """

    def __init__(self, use_gpu=True):
        super(CrossEntropyLoss, self).__init__()
        self.use_gpu = use_gpu
        self.crossentropy_loss = nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        """Compute the mean cross-entropy loss of a batch.

        Args:
            inputs: prediction matrix (before softmax) with shape
                (batch_size, num_classes)
            targets: ground truth class indices with shape (batch_size)

        Returns:
            Scalar tensor holding the loss.
        """
        # An unconditional ``targets.cuda()`` fails without CUDA and
        # mismatches CPU logits; follow the logits' device instead.
        targets = targets.to(inputs.device)
        return self.crossentropy_loss(inputs, targets)
# Label smoothing regularizes the classifier to reduce over-fitting.
class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.

    Reference:
        Szegedy et al. Rethinking the Inception Architecture for Computer
        Vision. CVPR 2016.

    Equation: y = (1 - epsilon) * y + epsilon / K.

    Args:
        num_classes (int): number of classes (K in the equation).
        epsilon (float): smoothing weight.
        use_gpu (bool): stored on the instance and kept for backward
            compatibility. The one-hot targets are always built on the
            device of the logits.
    """

    def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """Compute the label-smoothed cross-entropy loss of a batch.

        Args:
            inputs: prediction matrix (before softmax) with shape
                (batch_size, num_classes)
            targets: ground truth class indices with shape (batch_size)

        Returns:
            Scalar tensor: per-class loss averaged over the batch, then
            summed over classes.
        """
        log_probs = self.logsoftmax(inputs)
        # Build the one-hot matrix directly on the logits' device/dtype
        # instead of a CPU round-trip followed by an unconditional .cuda().
        # scatter_ requires an int64 index.
        index = targets.to(device=log_probs.device, dtype=torch.long).unsqueeze(1)
        one_hot = torch.zeros_like(log_probs).scatter_(1, index, 1.0)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        return (-smoothed * log_probs).mean(0).sum()
2.local distance
在models包下新建local_dist.py,主要包含以下几个函数:
- enclidean_dist(计算两组特征向量两两之间的欧氏距离)
- batch_enclidean_dist(按batch计算两组局部特征两两之间的欧氏距离)
- shortest_dist(用动态规划计算距离矩阵从左上角到右下角的最短路径长度)
- hard_example_mining(求最难样本对 triplet loss用)
- batch_local_dist(计算local distance)
代码如下:
#-*-coding:utf-8-*-
# 实现最短路径的计算及选取最难样本对
import torch
# Pairwise Euclidean distance between two sets of feature vectors.
def enclidean_dist(x, y):
    """Compute the Euclidean distance between every row of x and of y.

    Uses the expansion |a - b|^2 = |a|^2 + |b|^2 - 2 * a.b.

    Args:
        x: pytorch Variable, with shape [m, d]
        y: pytorch Variable, with shape [n, d]
    Returns:
        dist: pytorch Variable, with shape [m, n]
    """
    m, n = x.size(0), y.size(0)
    # xx[i, j] = |x_i|^2
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    # yy[i, j] = |y_j|^2: expand as [n, m] and transpose. Expanding the
    # [n, 1] column straight to (m, n) fails when m != n and yields
    # |y_i|^2 when m == n.
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    # Explicit matmul instead of the deprecated positional
    # addmm_(beta, alpha, mat1, mat2) overload.
    dist = xx + yy - 2 * torch.mm(x, y.t())
    # Clamp before sqrt for numerical stability (avoids sqrt of values <= 0).
    return dist.clamp(min=1e-12).sqrt()
def batch_enclidean_dist(x, y):
    """Compute pairwise Euclidean distances independently per batch item.

    Args:
        x: pytorch Variable, with shape [Batch size, Local part, Feature channel]
        y: pytorch Variable, with shape [Batch size, Local part, Feature channel]
    Returns:
        dist: pytorch Variable, with shape [Batch size, Local part, Local part]
    """
    # Fail fast on inputs that do not have the expected rank / matching sizes.
    assert len(x.size()) == 3
    assert len(y.size()) == 3
    assert x.size(0) == y.size(0)
    assert x.size(-1) == y.size(-1)

    N, m, _ = x.size()
    n = y.size(1)

    # Both terms are broadcast to shape [N, m, n].
    xx = torch.pow(x, 2).sum(-1, keepdim=True).expand(N, m, n)
    yy = torch.pow(y, 2).sum(-1, keepdim=True).expand(N, n, m).permute(0, 2, 1)
    # Explicit batched matmul instead of the deprecated positional
    # baddbmm_(beta, alpha, batch1, batch2) overload.
    dist = xx + yy - 2 * torch.bmm(x, y.permute(0, 2, 1))
    # Clamp before sqrt for numerical stability.
    return dist.clamp(min=1e-12).sqrt()
# Shortest-path length through a distance matrix.
def shortest_dist(dist_mat):
    """Accumulate the cheapest top-left to bottom-right path cost.

    Only moves to the right or downwards are considered; any trailing
    dimensions of `dist_mat` are processed in parallel.

    Args:
        dist_mat: pytorch Variable, available shape:
            1) [m, n]
            2) [m, n, N], N is batch size
            3) [m, n, *], * can be arbitrary additional dimensions
    Returns:
        dist: three cases corresponding to `dist_mat`:
            1) scalar
            2) pytorch Variable, with shape [N]
            3) pytorch Variable, with shape [*]
    """
    rows, cols = dist_mat.size()[:2]
    # table[r][c] holds the cheapest accumulated cost of reaching (r, c);
    # kept as nested lists so intermediate distances stay accessible.
    table = [[0] * cols for _ in range(rows)]
    for r in range(rows):
        for c in range(cols):
            cost = dist_mat[r, c]
            if r == 0 and c == 0:
                # Starting cell.
                table[r][c] = cost
                continue
            if r == 0:
                # First row: reachable only from the left.
                table[r][c] = table[r][c - 1] + cost
                continue
            if c == 0:
                # First column: reachable only from above.
                table[r][c] = table[r - 1][c] + cost
                continue
            table[r][c] = torch.min(table[r - 1][c], table[r][c - 1]) + cost
    return table[-1][-1]
def batch_local_dist(x, y):
    """Compute the aligned local distance for each pair (x[i], y[i]).

    Part-to-part Euclidean distances are squashed into [0, 1) and the
    shortest path through the resulting matrix is returned.

    Args:
        x: pytorch Variable, with shape [N, m, d]
        y: pytorch Variable, with shape [N, n, d]
    Returns:
        dist: pytorch Variable, with shape [N]
    """
    assert len(x.size()) == 3
    assert len(y.size()) == 3
    assert x.size(0) == y.size(0)
    assert x.size(-1) == y.size(-1)

    # shape [N, m, n]
    dist_mat = batch_enclidean_dist(x, y)
    # (e^d - 1) / (e^d + 1) equals tanh(d / 2). The exponential form
    # overflows to inf / inf = NaN for large distances; tanh does not.
    dist_mat = torch.tanh(dist_mat / 2.)
    # shortest_dist expects the matrix dims first and the batch last;
    # the result has shape [N].
    return shortest_dist(dist_mat.permute(1, 2, 0))
# Find the hardest sample pairs given a distance matrix and the labels.
# `return_inds` controls whether the indices of the selected positive and
# negative samples are returned too (default False). The AlignedReID loss
# passes True because both of its branches must use the same mined pairs.
def hard_example_mining(dist_mat, labels, return_inds=False):
    """Pick, per anchor, the farthest positive and the closest negative.

    Args:
        dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
        labels: pytorch LongTensor, with shape [N]
        return_inds: whether to also return the indices of the mined samples
    Returns:
        dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
        dist_an: pytorch Variable, distance(anchor, negative); shape [N]
        p_inds: tensor with shape [N] (only when `return_inds` is true);
            indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
        n_inds: tensor with shape [N] (only when `return_inds` is true);
            indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
    NOTE: Only consider the case in which all labels have same num of samples,
        thus we can cope with all anchors in parallel.
    """
    assert dist_mat.dim() == 2
    assert dist_mat.size(0) == dist_mat.size(1)
    # The distance matrix is (batch_size, batch_size).
    num = dist_mat.size(0)

    # [N, N] masks: entry (i, j) tells whether samples i and j share a label.
    label_grid = labels.expand(num, num)
    same_id = label_grid.eq(label_grid.t())
    diff_id = label_grid.ne(label_grid.t())

    # Each row keeps only its positives / negatives; the reshape to
    # [N, -1] relies on every label having the same number of samples.
    pos_dists = dist_mat[same_id].contiguous().view(num, -1)
    neg_dists = dist_mat[diff_id].contiguous().view(num, -1)

    # Hardest positive = largest distance, hardest negative = smallest.
    # Values and relative indices both have shape [N, 1].
    dist_ap, rel_pos = torch.max(pos_dists, 1, keepdim=True)
    dist_an, rel_neg = torch.min(neg_dists, 1, keepdim=True)
    # Drop the kept dimension -> shape [N].
    dist_ap = dist_ap.squeeze(1)
    dist_an = dist_an.squeeze(1)

    if not return_inds:
        return dist_ap, dist_an

    # Column indices 0..N-1 repeated on every row, in the dtype and on the
    # device of `labels`; shape [N, N].
    columns = torch.arange(0, num).long().to(labels)
    columns = columns.unsqueeze(0).expand(num, num)
    # Map the within-row relative indices back to absolute sample indices.
    pos_columns = columns[same_id].contiguous().view(num, -1)
    neg_columns = columns[diff_id].contiguous().view(num, -1)
    p_inds = torch.gather(pos_columns, 1, rel_pos).squeeze(1)
    n_inds = torch.gather(neg_columns, 1, rel_neg).squeeze(1)
    return dist_ap, dist_an, p_inds, n_inds
if __name__ == "__main__":
    # Smoke test: pairwise distances between two random batches of
    # 32 feature vectors with 2048 dimensions each.
    x, y = torch.randn(32, 2048), torch.randn(32, 2048)
    dist_mat = enclidean_dist(x, y)
3.triplet loss
依然存放到utils的losses.py文件中:
class TripletLoss(nn.Module):
    """Triplet loss with hard positive/negative mining.

    Reference:
        Hermans et al. In Defense of the Triplet Loss for Person
        Re-Identification. arXiv:1703.07737.

    Code imported from
    https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.

    Args:
        margin (float): margin for triplet.
        mutual_flag (bool): when true, ``forward`` also returns the
            pairwise distance matrix.
    """

    def __init__(self, margin=0.3, mutual_flag=False):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        self.mutual = mutual_flag

    def forward(self, inputs, targets):
        """Compute the batch-hard triplet loss.

        Args:
            inputs: feature matrix with shape (batch_size, feat_dim)
            targets: ground truth labels with shape (batch_size)

        Returns:
            The scalar loss, or ``(loss, dist)`` when ``mutual_flag`` was
            set, where ``dist`` is the (batch_size, batch_size) distance
            matrix.
        """
        n = inputs.size(0)
        # Pairwise distance via |a - b|^2 = |a|^2 + |b|^2 - 2 * a.b.
        # Explicit matmul replaces the deprecated positional
        # addmm_(beta, alpha, mat1, mat2) overload.
        sq_norm = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = sq_norm + sq_norm.t() - 2 * torch.mm(inputs, inputs.t())
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability

        # For each anchor, find the hardest positive and negative.
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
            dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
        dist_ap = torch.cat(dist_ap)
        dist_an = torch.cat(dist_an)

        # Ranking hinge loss; y = 1 asks for dist_an > dist_ap by the margin.
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        if self.mutual:
            return loss, dist
        return loss
class TripletAlignedReIDloss(nn.Module):
    """Triplet loss on global and local features with shared hard mining.

    The hardest positive/negative of each anchor is mined on the global
    distance matrix; the same pairs are then reused for the local branch.

    Reference:
        Hermans et al. In Defense of the Triplet Loss for Person
        Re-Identification. arXiv:1703.07737.

    Code imported from
    https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.

    Args:
        margin (float): margin for triplet, used by both branches.
        mutual_flag (bool): selects the return format of ``forward``.
    """

    def __init__(self, margin=0.3, mutual_flag=False):
        super(TripletAlignedReIDloss, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        self.ranking_loss_local = nn.MarginRankingLoss(margin=margin)
        self.mutual = mutual_flag

    def forward(self, inputs, targets, local_features):
        """Compute the global and local triplet losses.

        Args:
            inputs: feature matrix with shape (batch_size, feat_dim)
            targets: ground truth labels with shape (batch_size)
            local_features: 3-D tensor indexed by sample on dim 0; its
                last two dims are swapped before the local distance is
                computed.

        Returns:
            ``(global_loss + local_loss, dist)`` when ``mutual_flag`` was
            set, otherwise ``(global_loss, local_loss)``.
        """
        n = inputs.size(0)
        # Squared norms of each row, broadcast to [n, n].
        sq_norm = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        # |a|^2 + |b|^2 - 2 * a.b; explicit matmul replaces the deprecated
        # positional addmm_(beta, alpha, mat1, mat2) overload.
        dist = sq_norm + sq_norm.t() - 2 * torch.mm(inputs, inputs.t())
        # Clamp away from zero before the square root for numerical stability.
        dist = dist.clamp(min=1e-12).sqrt()

        # Mine on the global distances and keep the indices so the local
        # branch uses exactly the same positive/negative samples.
        dist_ap, dist_an, p_inds, n_inds = hard_example_mining(
            dist, targets, return_inds=True)
        p_inds, n_inds = p_inds.long(), n_inds.long()

        local_features = local_features.permute(0, 2, 1)
        p_local_feature = local_features[p_inds]
        n_local_feature = local_features[n_inds]
        local_dist_ap = batch_local_dist(local_features, p_local_feature)
        local_dist_an = batch_local_dist(local_features, n_local_feature)

        # y = 1 asks for the negative distance to exceed the positive one.
        y = torch.ones_like(dist_an)
        global_loss = self.ranking_loss(dist_an, dist_ap, y)
        local_loss = self.ranking_loss_local(local_dist_an, local_dist_ap, y)
        if self.mutual:
            return global_loss + local_loss, dist
        return global_loss, local_loss
if __name__ == "__main__":
    # Smoke test on a batch of 32: identities 1..8 with 4 samples each.
    target = torch.Tensor([pid for pid in range(1, 9) for _ in range(4)])
    features = torch.rand(32, 2048)
    local_features = torch.randn(32, 128, 8)
    triloss = TripletAlignedReIDloss()
    global_loss, local_loss = triloss(features, target, local_features)
    # Drop into an interactive shell to inspect the results.
    embed()
4.loss求和函数
def DeepSupervision(criterion, xs, y):
    """Sum one criterion over several predictions sharing one target.

    Args:
        criterion: loss function, called as ``criterion(x, y)``
        xs: tuple of inputs
        y: ground truth

    Returns:
        The accumulated loss; ``0.`` when ``xs`` is empty.
    """
    total = 0.
    for prediction in xs:
        total += criterion(prediction, y)
    return total