- pytorch实现yolo-v3 (源码阅读和复现) – 001
- pytorch实现yolo-v3 (源码阅读和复现) – 002
- pytorch实现yolo-v3 (源码阅读和复现) – 003算法分析
- pytorch实现yolo-v3 (源码阅读和复现) – 004算法分析
- pytorch实现yolo-v3 (源码阅读和复现) – 005
对上一步模型直接检测层的预测结果进行进一步过滤, 核心还是nms
1.核心代码
def write_results(predictions, confidence, num_class, nms=True, nms_thresh=0.4):
# 保留预测结果中置信度大于给定阈值的部分
# confidence: shape=(1,10647, 85)
# mask: shape=(1,10647) => 增加一维度之后 (1, 10647, 1)
mask = (predictions[:, :, 4] > confidence).float().unsqueeze(2)
predictions = predictions*mask # 小于置信度的条目值全为0, 剩下部分不变
# 如果没有检测任何有效目标,返回值为0
ind_nz = torch.nonzero(predictions[:, :, 4].squeeze()).squeeze()
if ind_nz.size(0) == 0:
return 0
# predictions = predictions[:, ind_nz, :]
'''
保留预测结果中置信度大于阈值的bbox
下面开始为nms准备
'''
# prediction的前五个数据分别表示 (Cx, Cy, w, h, score)
bbox = predictions.new(predictions.shape)
bbox[:, :, 0] = (predictions[:, :, 0] - predictions[:, :, 2]/2) # x1 = Cx - w/2
bbox[:, :, 1] = (predictions[:, :, 1] - predictions[:, :, 3]/2) # y1 = Cy - h/2
bbox[:, :, 2] = (predictions[:, :, 0] + predictions[:, :, 2]/2) # x2 = Cx + w/2
bbox[:, :, 3] = (predictions[:, :, 1] + predictions[:, :, 3]/2) # y2 = Cy + h/2
predictions[:, :, :4] = bbox[:, :, :4] # 计算后的新坐标复制回去
batch_size = predictions.size(0) # dim=0
# output = predictions.new(1, predictions.size(2)+1) # shape=(1,85+1)
write = False # 拼接结果到output中最后返回
for ind in range(batch_size):
# 选择此batch中第ind个图像的预测结果
prediction = predictions[ind]
# 结果过滤
ind_nz = torch.nonzero(prediction[:, 4].squeeze()).squeeze()
if ind_nz.size(0) == 0:
continue
prediction = prediction[ind_nz, :]
# print(prediction.shape) # shape=(10647->14, 85)
# 最大值, 最大值索引, 按照dim=1 方向计算
max_score, max_score_ind = torch.max(prediction[:, 5:], 1) # prediction[:, 5:]表示每一分类的分数
# 维度扩展
# max_score: shape=(10647->14) => (10647->14,1)
max_score = max_score.float().unsqueeze(1)
max_score_ind = max_score_ind.float().unsqueeze(1)
seq = (prediction[:, :5], max_score, max_score_ind) # 取前五
prediction = torch.cat(seq, 1) # shape=(10647, 5+1+1=7)
# print(prediction.shape)
# 获取当前图像检测结果中出现的所有类别
try:
image_classes = unique(prediction[:, -1]) # tensor, shape=(n)
except:
continue
# 执行classwise nms
for cls in image_classes:
# 分离检测结果中属于当前类的数据
# -1: cls_index, -2: score
class_mask = (prediction[:, -1] == cls) # shape=(n)
class_mask_ind = torch.nonzero(class_mask).squeeze() # shape=(n,1) => (n)
# prediction_: shape(n,7)
prediction_class = prediction[class_mask_ind].view(-1, 7) # 从prediction中取出属于cls类别的所有结果,为下一步的nms的输入
''' 到此步 prediction_class 已经存在了我们需要进行非极大值抑制的数据 '''
# 开始 nms
# 按照score排序, 由大到小
# 最大值最上面
score_sort_ind = torch.sort(prediction_class[:, 4], descending=True)[1] # [0] 排序结果, [1]排序索引
prediction_class = prediction_class[score_sort_ind]
cnt = prediction_class.size(0) # 个数
'''开始执行 "非极大值抑制" 操作'''
if nms:
for i in range(cnt):
# 对已经有序的结果,每次开始更新后索引加一,挨个与后面的结果比较
try:
ious = bbox_iou(prediction_class[i].unsqueeze(0), prediction_class[i+1:])
except ValueError:
break
except IndexError:
break
# 计算出需要移除的item
iou_mask = (ious < nms_thresh).float().unsqueeze(1)
prediction_class[i+1:] *= iou_mask # 保留i自身
# 开始移除
non_zero_ind = torch.nonzero(prediction_class[:, 4].squeeze())
prediction_class = prediction_class[non_zero_ind].view(-1, 7)
# iou_mask = (ious < nms_thresh).float() # shape=(n)
# non_zero_ind = torch.nonzero(iou_mask).squeeze()+1 # 会为空,导致出错
# prediction_class = prediction_class[non_zero_ind].view(-1, 7)
# 当前类的nms执行完之后,保存结果
batch_ind = prediction_class.new(prediction_class.size(0), 1).fill_(ind)
seq = batch_ind, prediction_class
if not write:
output = torch.cat(seq, 1)
write = True
else:
out = torch.cat(seq, 1)
output = torch.cat((output, out))
return output
2. 算法分析
第一步显示对目标区域置信度低于阈值的目标(低于阈值认为是bg)剔除掉, 后面的结果在进行nms过滤
在做nms之前, 对bbox坐标进行了变换, 从( Cx, Cy, w,h)变为(x1,y1, x2,y2),这样方便计算 iou
预测结果是一个batch,包含了n>=1张图像, 开始循环{
取出当前图像的预测结果,
过滤掉当前一张图像中置信度低于阈值的结果
统计当前预测结果包含的分类(先做排序,由大到小), 循环{
取出当前图像预测结果中属于当前类的预测结果,
对当前类执行nms,循环
{
取当前第i项和后面[i+1:]分别计算iou, 统计重叠区域大于阈值的部分,剔除掉,
更新预测结果,知道索引越界
}
此时的预测结果中保留的值已经是有效的了,放入到output返回值中之前,需要对其在扩展以为,放入信息:所属图像索引
}
}