最近在做目标检测的项目,本来以为Pytorch会自带iou损失,结果找了一圈没找到,于是自己实现了一下:
需要的同学自取
def iou_loss(predicted_boxes, target_boxes):
"""
计算IoU损失
Args:
predicted_boxes: 预测的bbox坐标,形状为(N, 4),其中N是batch size。
target_boxes: 真实的bbox坐标,形状为(N, 4),其中N是batch size。
Returns:
iou_loss: IoU损失,形状为(N,)。
"""
# 计算预测框和真实框的左上角和右下角坐标
pred_x1, pred_y1, pred_x2, pred_y2 = predicted_boxes[:, 0], predicted_boxes[:, 1], predicted_boxes[:,
2], predicted_boxes[:, 3]
true_x1, true_y1, true_x2, true_y2 = target_boxes[:, 0], target_boxes[:, 1], target_boxes[:, 2], target_boxes[:, 3]
# 计算交集和并集的左上角和右下角坐标
xi1 = torch.max(pred_x1, true_x1)
yi1 = torch.max(pred_y1, true_y1)
xi2 = torch.min(pred_x2, true_x2)
yi2 = torch.min(pred_y2, true_y2)
# 计算交集和并集的面积
inter_area = torch.clamp(xi2 - xi1, min=0) * torch.clamp(yi2 - yi1, min=0)
pred_area = (pred_x2 - pred_x1) * (pred_y2 - pred_y1)
true_area = (true_x2 - true_x1) * (true_y2 - true_y1)
union_area = pred_area + true_area - inter_area
# 计算IoU
iou = inter_area / union_area
# 计算IoU损失
iou_loss = 1.0 - iou
return iou_loss
需要注意的是,本代码的损失是基于框的左上角和右下角两个点的坐标值的:
所以小伙伴在使用的时候也要注意,传入的数据格式