【mmdetection】代码对比
在这里主要对比一下solo的官方代码和复现版本。
官方代码 https://github.com/WXinlong/SOLO
复现版本 https://github.com/Epiphqny/SOLO
- config
img_norm_cfg = dict(
mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_rgb=False)
-
scale_ranges表示不同
官方代码表示((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048))
复现版本表示((-1, 64), (64, 128), (128, 256), (256, 512), (512, INF)) INF = 1e8
-
cls branch不同
官方代码每层共享一个cls branch;复现版本每层都有一个cls branch;前者就会少一些参数 -
FPN设置
我自己在复现版本上改了FPN的采样率;而官方代码没有改FPN,直接把把输出的特征图插值 -
coord编码方式不同
官方代码的coord写法更为简略,复现版本没有该部分,我参考CoordConv
x_range = torch.linspace(-1, 1, ins_feat.shape[-1], device=ins_feat.device)
y_range = torch.linspace(-1, 1, ins_feat.shape[-2], device=ins_feat.device)
y, x = torch.meshgrid(y_range, x_range)
y = y.expand([ins_feat.shape[0], 1, -1, -1])
x = x.expand([ins_feat.shape[0], 1, -1, -1])
coord_feat = torch.cat([x, y], 1)
ins_feat = torch.cat([ins_feat, coord_feat], 1)
-
质心求法不同
官方代码通过scipy里的ndimage.measurements.center_of_mass
函数直接求得,复现版本没有该部分,我参考polarmask求质心的方法。 -
mask branch
官方代码与论文一致,每一层的mask经上采样2倍再计算loss,预测mask和target mask层层对应,s^2处于维度1。而复现版本将所有的正样本对应的mask都resize到原图的四分之一统一计算loss。 -
focal loss
官方代码中avg_factor
为总共实例个数+1,dice loss是对每一个层计算一下算平均;复现版本中是总共实例个数+batch数 -
label设置方式
官方代码巧用zip,实现列表维度的转换,官方代码以gt mask的重心为出发点,来分配label。而复现版本先求出特征图上点对应原图的坐标, 再计算这些点到重心点的距离,看是否在范围内,有点繁琐。
关于代码中zip部分可见【python】zip和*。
给一张图分配label的代码如下:
def solo_target_single(self,
gt_bboxes_raw,
gt_labels_raw,
gt_masks_raw,
featmap_sizes=None):
# featmap_sizes为5层mask输出的大小
device = gt_labels_raw[0].device
# ins
gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * (
gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1])) # 边
ins_label_list = []
cate_label_list = []
ins_ind_label_list = []
for (lower_bound, upper_bound), stride, featmap_size, num_grid \
in zip(self.scale_ranges, self.strides, featmap_sizes, self.seg_num_grids):
ins_label = torch.zeros([num_grid ** 2, featmap_size[0], featmap_size[1]], dtype=torch.uint8, device=device) # torch.Size([s^2, h, w])
cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device) # torch.Size([s, s])
ins_ind_label = torch.zeros([num_grid ** 2], dtype=torch.bool, device=device) # torch.Size([s^2])
hit_indices = ((gt_areas >= lower_bound) & (gt_areas <= upper_bound)).nonzero().flatten() # gt box中属于该层范围的gt索引
if len(hit_indices) == 0:
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label_list.append(ins_ind_label)
continue
gt_bboxes = gt_bboxes_raw[hit_indices]
gt_labels = gt_labels_raw[hit_indices]
gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...]
half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
output_stride = stride / 2 # mask分支是输出特征图的2倍
for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
if seg_mask.sum() < 10:
continue # 过滤掉一些太小的mask
# mass center
upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4) # 原图大小
center_h, center_w = ndimage.measurements.center_of_mass(seg_mask) # 用mask求质心
coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid)) # 属于该层哪个grid
coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
# left, top, right, down
top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))
top = max(top_box, coord_h-1)
down = min(down_box, coord_h+1)
left = max(coord_w-1, left_box)
right = min(right_box, coord_w+1)
cate_label[top:(down+1), left:(right+1)] = gt_label
# ins
seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
seg_mask = torch.Tensor(seg_mask)
for i in range(top, down+1):
for j in range(left, right+1):
label = int(i * num_grid + j)
ins_label[label, :seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask
ins_ind_label[label] = True
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label_list.append(ins_ind_label)
return ins_label_list, cate_label_list, ins_ind_label_list
-
测试阶段
官方代码mask branch的输出都插值为原图的1/4。cls branch先进行一下points_nms
(用2x2的窗口对特征图进行池化,留下池化后的像素值和特征图一样的像素)。 -
nms操作
先贴一下matrix nms的流程图
那就得先对mask打分得到seg_scores
,就是该张图片大于0.5的像素预测值的和/大于0.5的像素个数
。最后的分类的分数cate_scores
也得乘上seg_scores
。
将cate_scores
降序排列选出前n个进行matrix nms。
def matrix_nms(seg_masks, cate_labels, cate_scores, kernel='gaussian', sigma=2.0, sum_masks=None):
"""Matrix NMS for multi-class masks.
Args:
seg_masks (Tensor): shape (n, h, w)
cate_labels (Tensor): shape (n), mask labels in descending order
cate_scores (Tensor): shape (n), mask scores in descending order
kernel (str): 'linear' or 'gauss'
sigma (float): std in gaussian method
sum_masks (Tensor): The sum of seg_masks
Returns:
Tensor: cate_scores_update, tensors of shape (n)
"""
n_samples = len(cate_labels)
if n_samples == 0:
return []
if sum_masks is None:
sum_masks = seg_masks.sum((1, 2)).float()
seg_masks = seg_masks.reshape(n_samples, -1).float()
# inter.
inter_matrix = torch.mm(seg_masks, seg_masks.transpose(1, 0)) # 矩阵乘法
# union.
sum_masks_x = sum_masks.expand(n_samples, n_samples)
# iou.
iou_matrix = (inter_matrix / (sum_masks_x + sum_masks_x.transpose(1, 0) - inter_matrix)).triu(diagonal=1) # (n,n)
'''
sum_masks_x + sum_masks_x.transpose(1, 0)
1 2 3 1 1 1
1 2 3 + 2 2 2
1 2 3 3 3 3
'''
# label_specific matrix.
cate_labels_x = cate_labels.expand(n_samples, n_samples)
label_matrix = (cate_labels_x == cate_labels_x.transpose(1, 0)).float().triu(diagonal=1)
# 每一行代表和该mask类别一样的mask
# IoU compensation 要使3)式最小 那么iou就取最大
compensate_iou, _ = (iou_matrix * label_matrix).max(0) # 按列看去代表每个mask的(iou.,i)
compensate_iou = compensate_iou.expand(n_samples, n_samples).transpose(1, 0) # 每一行代表每个mask的(iou.,i)
# IoU decay
decay_iou = iou_matrix * label_matrix
# 每一行代表和其他mask(该mask类别一样且分数比该mask低)的iou 那么每一列就代表了论文里公式4)的iou_{ij}
# matrix nms
if kernel == 'gaussian':
decay_matrix = torch.exp(-1 * sigma * (decay_iou ** 2))
compensate_matrix = torch.exp(-1 * sigma * (compensate_iou ** 2))
decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0)
elif kernel == 'linear':
decay_matrix = (1-decay_iou)/(1-compensate_iou)
decay_coefficient, _ = decay_matrix.min(0)
else:
raise NotImplementedError
# update the score.
cate_scores_update = cate_scores * decay_coefficient
return cate_scores_update
- loss
def loss(self,
ins_preds,
cate_preds,
gt_bbox_list,
gt_label_list,
gt_mask_list,
img_metas,
cfg,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in
ins_preds]
ins_label_list, cate_label_list, ins_ind_label_list = multi_apply( # 每一张图片的5层输出的label
self.solo_target_single,
gt_bbox_list,
gt_label_list,
gt_mask_list,
featmap_sizes=featmap_sizes)
# ins_label_list : (batch_size)(5)(torch.Size([s^2, h, w]))
# cate_label_list : (batch_size)(5)(s,s)
# ins_ind_label_list : (batch_size)(5)(s^2) 代表s^2哪一个有mask
# ins
ins_labels = [torch.cat([ins_labels_level_img[ins_ind_labels_level_img, ...]
for ins_labels_level_img, ins_ind_labels_level_img in
zip(ins_labels_level, ins_ind_labels_level)], 0)
for ins_labels_level, ins_ind_labels_level in zip(zip(*ins_label_list), zip(*ins_ind_label_list))]
# ins_labels : (5)(torch.Size([num_obj, h, w])) 将一个batch都按先后顺序放在一起算,num_obj有可能为0
ins_preds = [torch.cat([ins_preds_level_img[ins_ind_labels_level_img, ...]
for ins_preds_level_img, ins_ind_labels_level_img in
zip(ins_preds_level, ins_ind_labels_level)], 0)
for ins_preds_level, ins_ind_labels_level in zip(ins_preds, zip(*ins_ind_label_list))]
# ins_preds : (5)(torch.Size([batch_size, s^2, h, w])) to (5)(torch.Size([num_obj, h, w]))
ins_ind_labels = [
torch.cat([ins_ind_labels_level_img.flatten()
for ins_ind_labels_level_img in ins_ind_labels_level])
for ins_ind_labels_level in zip(*ins_ind_label_list)
]
flatten_ins_ind_labels = torch.cat(ins_ind_labels)
num_ins = flatten_ins_ind_labels.sum()
# dice loss
loss_ins = []
for input, target in zip(ins_preds, ins_labels):
if input.size()[0] == 0:
continue
input = torch.sigmoid(input)
loss_ins.append(dice_loss(input, target))
loss_ins = torch.cat(loss_ins).mean() # 5个层算一下平均
loss_ins = loss_ins * self.ins_loss_weight
# cate
cate_labels = [
torch.cat([cate_labels_level_img.flatten()
for cate_labels_level_img in cate_labels_level])
for cate_labels_level in zip(*cate_label_list)
]
flatten_cate_labels = torch.cat(cate_labels)
cate_preds = [
cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels)
for cate_pred in cate_preds
]
flatten_cate_preds = torch.cat(cate_preds)
loss_cate = self.loss_cate(flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1)
return dict(
loss_ins=loss_ins,
loss_cate=loss_cate)