损失函数设计的重要性可想而知,设计出一个好的目标函数,不但可以训练出一个更精准的模型,同时可以加快模型模型训练。在接触目标检测之初,对于自己一个不小的挑战就是理解损失函数,这部分内容比较繁琐。我们来简单地看一看 YOLOv1 损失函数设计。
更详细内容可以参考 目标检测系列—深度解读 YOLOv1 (1)
#导入依赖库
from matplotlib.pyplot import cla
import torch
import torch.nn as nn
from utils import intersection_over_union
def __init__(self,S=7,B=2,C=20):
super(YOLOLoss,self).__init__()
self.mse = nn.MSELoss(reduction="sum")
# 网络的数量
self.S = S
# 每个网络输出边界框数量
self.B = B
# 每个网格预测类别数
self.C = C
self.lambda_noobj = 0.5
# 对定位损失添加权重
self.lambda_coord = 5
- 在预测边界框中是否存在目标的置信度,因为大部分边界框都是不包含目标,所以对于没有目标网格置信度概率损失减少其在损失函数所占比重
self.lambda_noobj = 0.5
- 对定位损失添加权重,增加定位损失权重
self.lambda_coord = 5
详细解读代码
predictions = predictions.reshape(-1,self.S,self.S,self.C + self.B*5)
输入 shape 为 (batch_size,features) 转换为 (batch_size,cellSize,cellSize,(class_numbers + bbox_numbers * (confidence + x1,y1,w,h)))
iou_b1 = intersection_over_union(preds[...,(self.C + 1):(self.C + 5)],target[...,(self.C + 1):(self.C + 5)])
iou_b2 = intersection_over_union(preds[...,(self.C + 6):(self.C + 10)],target[...,(self.C + 1):(self.C + 5)])
在预测 self.C + 1 到 self.C + 5 为第一个预测框中,中心坐标和边界框宽高,而 self.C + 6 到 self.C + 10,为第二个边界框的位置信息,这两边界框分别和标注对应网格位置的边界框做交并比,注意可能边界框并没有目标
ious = torch.cat([iou_b1.unsqueeze(0), iou_b2.unsqueeze(0)],dim=0)
输出 iou_b1 tensor 为标量,为标量增加一个维度,然后再 0 维度上对 iou_b1 和 iou_b2 进行堆叠
# a 和 b 分别是 intersection_over_union 函数输出交并比的值,然后使用 unsqueeze 扩展维度 [0.6] 和 [0.75], 然后对其进行堆叠
a = torch.tensor(0.6)
b = torch.tensor(0.75)
c = torch.cat([a.unsqueeze(0),b.unsqueeze(0)])
c
iou_maxes, bestbox = torch.max(ious,dim=0)
比较两个边界框 IoU 返回值 iou_maxes 返回 IoU 最大的值 bestbox 返回最大值对应索引
iou_maxes, bestbox = torch.max(c,dim=0)
print(iou_maxes,bestbox) #tensor(0.7500) tensor(1)
对于 target 结构为 (batch_size,cellSize,cellSize,(class_numbers + confidence + (x1,y1,w,h))),
exists_box = target[...,self.C].unsqueeze(3)
a = torch.randint(0,2,size=(2,7,7,85))
exists_box = a[...,80]
exists_box.shape # torch.Size([2, 7, 7])
exists_box = exists_box.unsqueeze(3)
exists_box.shape #torch.Size([2, 7, 7, 1])
exists_box
对于解读神经网络源码时,需要清楚地了解每一个步骤输入和输出数据的 shape
box_preds = exists_box * (
(bestbox * preds[...,(self.C + 6):(self.C + 10)] + (1-bestbox)*preds[...,(self.C + 1):(self.C + 5)])
)
位置损失
位置损失包括两个部分,一个是中心点误差,另一个是宽高的位置损失。这里 是一个符号函数,也就是当有目标是为 1 否则为 0,也就是只计算存在目标的网格预测边界框和真实边界框的中心点和宽度和高度的差值。
这里 bestbox 是取值为 0 或者 1 ,表示第一个 bbox 还是第二个 bbox 是与 target 的 Iou 的值最大,所以 bestbox 为 0 也就是保留第一个 bbox,所以 (1 - bestbox)=1 那么也就是保留 preds[...,(self.C + 1):(self.C + 5)]
反之亦然。
这里 exists_box
表示只计算有目标的网格的位置损失。
box_preds[...,2:4] = torch.sign(box_preds[...,2:4]) * torch.sqrt(torch.abs(box_preds[...,2:4] + 1e-6))
box_targets[...,2:4] = torch.sqrt(box_targets[...,2:4])
在 pytorch 对于一个负数进行开方运算会得到一个 nan
a = torch.tensor(-1.0)
torch.sqrt(a) #
所以这里先对其取绝对值,然后再进行开方运算后再用 torch.sign(a)
将符号还给输出结果
a = torch.tensor(-1.0)
torch.sign(a)*torch.sqrt(torch.abs(a)) #tensor(-1.)
用均方根来计算位置损失
box_loss = self.mse(
#(n,S,S,4)->(N*S*S,4)
torch.flatten(box_preds,end_dim=-2),
torch.flatten(box_targets,end_dim=-2),
)
a = torch.randn((2,7,7,4))
torch.flatten(a,end_dim=-2).shape #torch.Size([98, 4])
目标损失
每一个网格负责预测目标,首先会给出一个网格中是否包含目标置信度。YOLO 系列这样的 one stage 目标检测属于 dense 预测,也就是每个网格都会产生候选框,可能之前也叫做预测框。
网格置信度目标损失也分为 2 个部分进行计算,一个有目标的另一个是没有目标,通过通过权重
pred_box = (
bestbox * preds[...,(self.C + 5):(self.C + 6)] + (1-bestbox) * preds[...,(self.C ):(self.C + 1)]
)
首先计算 pred_box 也就是对于每一个网格保留其置信度较高的概率值。
# (N*S*S,1)
# 类别为,
object_loss = self.mse(
torch.flatten(exists_box * pred_box),
torch.flatten(exists_box * target[...,(self.C):(self.C + 1)])
)
然后计算有目标网格置信度概率值与这是概率值,这里应该就是 1.0 概率之间差,下面是计算不存在目标网格置信度损失值,这个时候需要考虑到两个 bbox 的置信度分别和真实值也就是 0 做平方差来计算损失值。
# (N,S,S,1) -> (N,S*S)
#
no_object_loss = self.mse(
torch.flatten((1 - exists_box) * preds[...,(self.C ):(self.C + 1)], start_dim=1),
torch.flatten((1 - exists_box) * target[..., (self.C ):(self.C + 1)],start_dim=1)
)
no_object_loss += self.mse(
torch.flatten((1 - exists_box) * preds[...,(self.C + 5):(self.C + 6)], start_dim=1),
torch.flatten((1 - exists_box) * target[..., (self.C ):(self.C + 1)],start_dim=1)
)
类别损失
类别损失相对比较简单,所以这里就不做过多解释了。
class_loss = self.mse(
torch.flatten(exists_box*predictions[...,:self.C],end_dim=-2),
torch.flatten(exists_box*target[...,:self.C],end_dim=-2)
)
loss = (
self.lambda_coord * box_loss
+ object_loss
+ self.lambda_noobj * no_object_loss
+ class_loss
)
return loss
最后需要将位置损失、目标损失和类别损失相加求和得到总的损失。
完整代码
class YoloLoss(nn.Module):
def __init__(self, S=7,B=2,C=80):
self.mse = nn.MSELoss(reduction="sum")
# 网络的数量
self.S = S
# 每个网络输出边界框数量
self.B = B
# 每个网格预测类别数
self.C = C
# 在预测边界框中是否存在目标的置信度,因为大部分边界框都是不包含目标,所以对于没有目标网格置信度概率损失减少其在损失函数所占比重
self.lambda_noobj = 0.5
# 对定位损失添加权重,增加定位损失权重
self.lambda_coord = 5
def forward(self, preds, target):
# 输入 shape 为 (batch_size,features) 转换为 (batch_size,cellSize,cellSize,(class_numbers + bbox_numbers * (confidence + x1,y1,w,h)))
preds = preds.reshape(-1,self.S,self.S,self.C + self.B*5)
# 在预测 self.C + 1 到 self.C + 5 为第一个预测框中,中心坐标和边界框宽高,而 self.C + 6 到 self.C + 10
# 为第二个边界框的位置信息,这两边界框分别和标注对应网格位置的边界框做交并比,注意可能边界框并没有目标
iou_b1 = intersection_over_union(preds[...,(self.C + 1):(self.C + 5)],target[...,(self.C + 1):(self.C + 5)])
iou_b2 = intersection_over_union(preds[...,(self.C + 6):(self.C + 10)],target[...,(self.C + 1):(self.C + 5)])
#输出 iou_b1 tensor 为标量,为标量增加一个维度,然后再 0 维度上对 iou_b1 和 iou_b2 进行堆叠
ious = torch.cat([iou_b1.unsqueeze(0), iou_b2.unsqueeze(0)],dim=0)
# 比较两个边界框 IoU 返回值 iou_maxes 返回 IoU 最大的值 bestbox 返回最大值对应索引
iou_maxes, bestbox = torch.max(ious,dim=0)
# 对于 target 结构为 (batch_size,cellSize,cellSize,(class_numbers + confidence + (x1,y1,w,h)))
exists_box = target[...,self.C].unsqueeze(3)
# 位置损失
#
box_preds = exists_box * (
(bestbox * preds[...,(self.C + 6):(self.C + 10)] + (1-bestbox)*preds[...,(self.C + 1):(self.C + 5)])
)
box_targets = exists_box * target[...,(self.C + 1):(self.C + 5)]
box_preds[...,2:4] = torch.sign(box_preds[...,2:4]) * torch.sqrt(torch.abs(box_preds[...,2:4] + 1e-6))
box_targets[...,2:4] = torch.sqrt(box_targets[...,2:4])
box_loss = self.mse(
#(n,S,S,4)->(N*S*S,4)
torch.flatten(box_preds,end_dim=-2),
torch.flatten(box_targets,end_dim=-2),
)
"""
目标损失
"""
# 保留 IoU 值高
pred_box = (
bestbox * preds[...,(self.C + 5):(self.C + 6)] + (1-bestbox) * preds[...,(self.C ):(self.C + 1)]
)
# (N*S*S,1)
# 类别为,
object_loss = self.mse(
torch.flatten(exists_box * pred_box),
torch.flatten(exists_box * target[...,(self.C):(self.C + 1)])
)
# (N,S,S,1) -> (N,S*S)
#
no_object_loss = self.mse(
torch.flatten((1 - exists_box) * preds[...,(self.C ):(self.C + 1)], start_dim=1),
torch.flatten((1 - exists_box) * target[..., (self.C ):(self.C + 1)],start_dim=1)
)
#
no_object_loss += self.mse(
torch.flatten((1 - exists_box) * preds[...,(self.C + 5):(self.C + 6)], start_dim=1),
torch.flatten((1 - exists_box) * target[..., (self.C ):(self.C + 1)],start_dim=1)
)
class_loss = self.mse(
torch.flatten(exists_box*predictions[...,:self.C],end_dim=-2),
torch.flatten(exists_box*target[...,:self.C],end_dim=-2)
)
loss = (
self.lambda_coord * box_loss
+ object_loss
+ self.lambda_noobj * no_object_loss
+ class_loss
)
return loss
我正在参与掘金技术社区创作者签约计划招募活动,点击链接报名投稿。