CenterLoss
CenterLoss是干嘛的?
我们知道:对于图像的分类问题,我们通常使用softmax来计算损失
上图就是我们使用softmax对手写数字图像进行不同数字进行颜色的分类
如果我们在softmax损失的基础上添加上一个centerloss损失,那么对于手写数字图像的分类就如下图所示了
那么centerloss是怎么把图像分类成上面的这个样子的呢?
在训练过程中,我们同时更新中心并最小化深度特征与其对应的类中心之间的距离
通过联合监督,不仅扩大了类间特征差异,而且减少了类内特征差异
至此,我们大体知道了centerloss的作用了,那么centerLoss具体是如何实现的呢?
对于公式的解读:
- c y i c_{y_i} cyi表示第 y i {y_i} yi个类别的特征中心
- x i x_i xi表示特征值
公式的大体意思是:训练中,每一个batch的样本的特征与当前类别中心的距离的平方和越小越好(也就是类内距越小越好)
比如上图:红蓝两个样本,其中X所代表的就是类别的中心点 c y i c_{y_i} cyi
centerloss就是将离散的点向中心点聚拢从而增大了类内的间距
centerLoss和Softmax的结合
代码实现
class CenterLoss(nn.Module):
def __init__(self, num_class, num_feature):
super().__init__()
self.center = nn.Parameter(torch.rand(num_class, num_feature))
def forward(self, features, targets):
batch_size = features.size(0)
loss = torch.sum((features - torch.index_select(self.center,0,targets))**2)/batch_size
return loss