[mmdetection] SOLO: official code vs. reproduction

[mmdetection] Code comparison

This post mainly compares the official SOLO code with a community reproduction.
Official code: https://github.com/WXinlong/SOLO
Reproduction: https://github.com/Epiphqny/SOLO

  • config
img_norm_cfg = dict(
    mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_rgb=False)
  • scale_ranges differ
    Official: ((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048))
    Reproduction: ((-1, 64), (64, 128), (128, 256), (256, 512), (512, INF)), where INF = 1e8
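The practical difference is that the official ranges overlap while the reproduced ones partition the scale axis. A small pure-Python sketch (the gating mirrors the `(gt_areas >= lower_bound) & (gt_areas <= upper_bound)` test used later in target assignment):

```python
INF = 1e8
official = ((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048))
reproduced = ((-1, 64), (64, 128), (128, 256), (256, 512), (512, INF))

def hit_levels(ranges, scale):
    # scale = sqrt(w * h) of a ground-truth box, as in solo_target_single
    return [i for i, (lo, hi) in enumerate(ranges) if lo <= scale <= hi]

# With the overlapping official ranges, one object can supervise two levels:
hit_levels(official, 100)    # -> [1, 2]
hit_levels(reproduced, 100)  # -> [1]
```

So under the official convention a mid-sized object is assigned to two adjacent FPN levels, which acts as a form of scale augmentation.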

  • cls branch
    The official code shares a single cls branch across all levels, while the reproduction gives each level its own cls branch; the former therefore has fewer parameters.
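The parameter saving is easy to quantify with a back-of-the-envelope sketch. The numbers below are hypothetical (a single 3x3 conv head from 256 channels to 80 classes over 5 levels), not taken from either repo:

```python
# Hypothetical head: one 3x3 conv, 256 in-channels, 80 classes, 5 FPN levels.
k, c_in, c_out, levels = 3, 256, 80, 5
params_one_head = c_in * c_out * k * k + c_out  # conv weights + bias

shared_params = params_one_head              # official: one head reused on every level
per_level_params = levels * params_one_head  # reproduction: a separate head per level
```

Under these assumptions the per-level design simply costs 5x the head parameters.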

  • FPN设置
    我自己在复现版本上改了FPN的采样率;而官方代码没有改FPN,直接把把输出的特征图插值

  • coord encoding
    The official coord code is more concise; the reproduction lacked this part, so I added it with reference to CoordConv:

# CoordConv-style: append normalized x/y coordinate maps in [-1, 1]
x_range = torch.linspace(-1, 1, ins_feat.shape[-1], device=ins_feat.device)
y_range = torch.linspace(-1, 1, ins_feat.shape[-2], device=ins_feat.device)
y, x = torch.meshgrid(y_range, x_range)
y = y.expand([ins_feat.shape[0], 1, -1, -1])  # (N, 1, H, W)
x = x.expand([ins_feat.shape[0], 1, -1, -1])  # (N, 1, H, W)
coord_feat = torch.cat([x, y], 1)             # (N, 2, H, W)
ins_feat = torch.cat([ins_feat, coord_feat], 1)
  • mass-center computation
    The official code obtains the mass center directly via scipy's ndimage.measurements.center_of_mass; the reproduction lacked this part, so I followed PolarMask's method of finding the mass center.
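For a binary mask, ndimage.measurements.center_of_mass reduces to the mean coordinate of the foreground pixels; a NumPy sketch:

```python
import numpy as np

mask = np.zeros((6, 6), dtype=np.uint8)
mask[1:4, 2:5] = 1  # a 3x3 foreground blob

# Mean (row, col) of the foreground pixels; this is what
# ndimage.measurements.center_of_mass returns for a binary mask.
ys, xs = np.nonzero(mask)
center_h, center_w = ys.mean(), xs.mean()  # -> (2.0, 3.0)
```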

  • mask branch
    The official code follows the paper: each level's masks are upsampled 2x before computing the loss, predicted masks and target masks correspond level by level, and s^2 sits in dimension 1. The reproduction instead resizes the masks of all positive samples to 1/4 of the input image and computes the loss over them jointly.

  • focal loss
    In the official code avg_factor is the total number of instances + 1, and the dice loss is computed per level and then averaged; in the reproduction it is the total number of instances + the batch size.

  • label assignment
    The official code makes clever use of zip to transpose list dimensions, and assigns labels starting from the mass center of each gt mask. The reproduction instead first maps feature-map points back to image coordinates, then computes each point's distance to the mass center and checks whether it lies within range, which is somewhat convoluted.
    For the zip part of the code, see [python] zip and *.
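As a quick illustration (toy strings standing in for tensors), zip(*nested) transposes the outer two list dimensions, which is exactly how the loss function regroups per-image lists into per-level lists:

```python
# Per-image list of per-level labels (strings stand in for tensors).
per_image = [["img0_lvl0", "img0_lvl1", "img0_lvl2"],
             ["img1_lvl0", "img1_lvl1", "img1_lvl2"]]

# zip(*...) transposes: the outer dimension is now the level.
per_level = list(zip(*per_image))
# per_level[0] -> ("img0_lvl0", "img1_lvl0")
```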
    The label-assignment code for a single image is as follows:

def solo_target_single(self,
                       gt_bboxes_raw,
                       gt_labels_raw,
                       gt_masks_raw,
                       featmap_sizes=None):
    # featmap_sizes: output sizes of the 5 mask levels
    device = gt_labels_raw[0].device
    # ins
    gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * (
            gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))  # sqrt(w * h): equivalent side length

    ins_label_list = []
    cate_label_list = []
    ins_ind_label_list = []
    for (lower_bound, upper_bound), stride, featmap_size, num_grid \
            in zip(self.scale_ranges, self.strides, featmap_sizes, self.seg_num_grids):

        ins_label = torch.zeros([num_grid ** 2, featmap_size[0], featmap_size[1]], dtype=torch.uint8, device=device)  # torch.Size([s^2, h, w])
        cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device)  # torch.Size([s, s])
        ins_ind_label = torch.zeros([num_grid ** 2], dtype=torch.bool, device=device)  # torch.Size([s^2])
        hit_indices = ((gt_areas >= lower_bound) & (gt_areas <= upper_bound)).nonzero().flatten()  # indices of gt boxes whose scale falls in this level's range

        if len(hit_indices) == 0:
            ins_label_list.append(ins_label)
            cate_label_list.append(cate_label)
            ins_ind_label_list.append(ins_ind_label)
            continue
        gt_bboxes = gt_bboxes_raw[hit_indices]
        gt_labels = gt_labels_raw[hit_indices]
        gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...]

        half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
        half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma

        output_stride = stride / 2  # the mask branch output is 2x the feature-map size

        for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
            if seg_mask.sum() < 10:
                continue  # filter out masks that are too small
            # mass center
            upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)  # input image size
            center_h, center_w = ndimage.measurements.center_of_mass(seg_mask)  # mass center of the mask
            coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))  # which grid cell of this level it falls into
            coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))

            # left, top, right, down
            top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
            down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
            left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
            right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))

            top = max(top_box, coord_h - 1)
            down = min(down_box, coord_h + 1)
            left = max(coord_w - 1, left_box)
            right = min(right_box, coord_w + 1)

            cate_label[top:(down + 1), left:(right + 1)] = gt_label
            # ins
            seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
            seg_mask = torch.Tensor(seg_mask)
            for i in range(top, down + 1):
                for j in range(left, right + 1):
                    label = int(i * num_grid + j)
                    ins_label[label, :seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask
                    ins_ind_label[label] = True
        ins_label_list.append(ins_label)
        cate_label_list.append(cate_label)
        ins_ind_label_list.append(ins_ind_label)
    return ins_label_list, cate_label_list, ins_ind_label_list
  • test stage
    The official code interpolates every mask branch output to 1/4 of the input image. The cls branch first goes through points_nms (the score map is max-pooled with a 2x2 window, and only the pixels whose pooled value equals their original value are kept).
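The points_nms step keeps only local maxima of the score map. A NumPy sketch of the idea (the official code does this with F.max_pool2d, kernel 2, stride 1; with the asymmetric padding below, a pixel survives if it is not smaller than its top, left, and top-left neighbours):

```python
import numpy as np

def points_nms(heat):
    # Pad top/left so each pixel is compared against its
    # top, left, and top-left neighbours plus itself.
    p = np.pad(heat.astype(float), ((1, 0), (1, 0)), constant_values=-np.inf)
    hmax = np.maximum.reduce([p[1:, 1:], p[1:, :-1], p[:-1, 1:], p[:-1, :-1]])
    return heat * (hmax == heat)  # keep pixels that equal the pooled maximum
```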

  • nms
    First, the Matrix NMS flowchart:
    [Matrix NMS flowchart: figure not recovered]
    Each mask first needs a score, seg_scores: the sum of the predicted values of the pixels above 0.5, divided by the number of such pixels. The final classification scores, cate_scores, are then multiplied by seg_scores.
    cate_scores are sorted in descending order and the top n masks are passed through Matrix NMS.
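A tiny numeric example of the mask scoring (the 2x3 mask values and the 0.8 classification score are made up for illustration):

```python
import numpy as np

mask_pred = np.array([[0.9, 0.6, 0.4],
                      [0.7, 0.3, 0.1]])
fg = mask_pred > 0.5                           # foreground pixels
seg_score = (mask_pred * fg).sum() / fg.sum()  # (0.9 + 0.6 + 0.7) / 3
cate_score = 0.8 * seg_score                   # final score fed to Matrix NMS
```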

def matrix_nms(seg_masks, cate_labels, cate_scores, kernel='gaussian', sigma=2.0, sum_masks=None):
    """Matrix NMS for multi-class masks.

    Args:
        seg_masks (Tensor): shape (n, h, w)
        cate_labels (Tensor): shape (n), mask labels in descending score order
        cate_scores (Tensor): shape (n), mask scores in descending order
        kernel (str): 'linear' or 'gaussian'
        sigma (float): std in gaussian method
        sum_masks (Tensor): The sum of seg_masks

    Returns:
        Tensor: cate_scores_update, tensor of shape (n)
    """
    n_samples = len(cate_labels)
    if n_samples == 0:
        return []
    if sum_masks is None:
        sum_masks = seg_masks.sum((1, 2)).float()
    seg_masks = seg_masks.reshape(n_samples, -1).float()
    # inter.
    inter_matrix = torch.mm(seg_masks, seg_masks.transpose(1, 0))  # matrix multiplication
    # union.
    sum_masks_x = sum_masks.expand(n_samples, n_samples)
    # iou.
    iou_matrix = (inter_matrix / (sum_masks_x + sum_masks_x.transpose(1, 0) - inter_matrix)).triu(diagonal=1)  # (n, n)
    '''
    sum_masks_x + sum_masks_x.transpose(1, 0):
    1 2 3      1 1 1
    1 2 3   +  2 2 2
    1 2 3      3 3 3
    '''
    # label_specific matrix.
    cate_labels_x = cate_labels.expand(n_samples, n_samples)
    label_matrix = (cate_labels_x == cate_labels_x.transpose(1, 0)).float().triu(diagonal=1)
    # each row marks the masks that share that mask's class

    # IoU compensation: to minimize Eq. (3), take the maximum IoU
    compensate_iou, _ = (iou_matrix * label_matrix).max(0)  # per column: each mask's iou_{.,i}
    compensate_iou = compensate_iou.expand(n_samples, n_samples).transpose(1, 0)  # each row holds one mask's iou_{.,i}

    # IoU decay
    decay_iou = iou_matrix * label_matrix
    # each row holds a mask's IoU with the lower-scored masks of the same class,
    # so each column is the iou_{ij} of Eq. (4) in the paper

    # matrix nms
    if kernel == 'gaussian':
        decay_matrix = torch.exp(-1 * sigma * (decay_iou ** 2))
        compensate_matrix = torch.exp(-1 * sigma * (compensate_iou ** 2))
        decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0)
    elif kernel == 'linear':
        decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
        decay_coefficient, _ = decay_matrix.min(0)
    else:
        raise NotImplementedError

    # update the score.
    cate_scores_update = cate_scores * decay_coefficient
    return cate_scores_update
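To sanity-check the decay logic, the gaussian branch can be replayed in NumPy on a toy IoU matrix. This is an independent re-derivation, not the official implementation; gaussian_decay is my own helper name, and the inputs (two same-class masks with IoU 0.8, scores 0.9 and 0.6) are made up:

```python
import numpy as np

def gaussian_decay(iou, labels, scores, sigma=2.0):
    # iou: (n, n) full IoU matrix for masks sorted by descending score
    n = len(scores)
    same = np.triu((labels[:, None] == labels[None, :]).astype(float), k=1)
    decay_iou = np.triu(iou, k=1) * same   # IoU with higher-scored same-class masks
    comp = decay_iou.max(axis=0)           # IoU compensation per mask
    comp = np.tile(comp[:, None], (1, n))  # row i holds suppressor i's compensation
    decay = np.exp(-sigma * decay_iou ** 2) / np.exp(-sigma * comp ** 2)
    return scores * decay.min(axis=0)

iou = np.array([[1.0, 0.8],
                [0.8, 1.0]])
out = gaussian_decay(iou, np.array([1, 1]), np.array([0.9, 0.6]))
# The top mask keeps its score; the overlapping one decays by exp(-2 * 0.8**2).
```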
  • loss
def loss(self,
         ins_preds,
         cate_preds,
         gt_bbox_list,
         gt_label_list,
         gt_mask_list,
         img_metas,
         cfg,
         gt_bboxes_ignore=None):
    featmap_sizes = [featmap.size()[-2:] for featmap in
                     ins_preds]

    ins_label_list, cate_label_list, ins_ind_label_list = multi_apply(  # per-image labels for the 5 output levels
        self.solo_target_single,
        gt_bbox_list,
        gt_label_list,
        gt_mask_list,
        featmap_sizes=featmap_sizes)
    # ins_label_list : (batch_size)(5)(torch.Size([s^2, h, w]))
    # cate_label_list : (batch_size)(5)(s, s)
    # ins_ind_label_list : (batch_size)(5)(s^2), marks which of the s^2 cells carry a mask

    # ins
    ins_labels = [torch.cat([ins_labels_level_img[ins_ind_labels_level_img, ...]
                             for ins_labels_level_img, ins_ind_labels_level_img in
                             zip(ins_labels_level, ins_ind_labels_level)], 0)
                  for ins_labels_level, ins_ind_labels_level in zip(zip(*ins_label_list), zip(*ins_ind_label_list))]

    # ins_labels : (5)(torch.Size([num_obj, h, w])); the whole batch is concatenated in order, and num_obj may be 0
    ins_preds = [torch.cat([ins_preds_level_img[ins_ind_labels_level_img, ...]
                            for ins_preds_level_img, ins_ind_labels_level_img in
                            zip(ins_preds_level, ins_ind_labels_level)], 0)
                 for ins_preds_level, ins_ind_labels_level in zip(ins_preds, zip(*ins_ind_label_list))]
    # ins_preds : (5)(torch.Size([batch_size, s^2, h, w])) to (5)(torch.Size([num_obj, h, w]))

    ins_ind_labels = [
        torch.cat([ins_ind_labels_level_img.flatten()
                   for ins_ind_labels_level_img in ins_ind_labels_level])
        for ins_ind_labels_level in zip(*ins_ind_label_list)
    ]
    flatten_ins_ind_labels = torch.cat(ins_ind_labels)

    num_ins = flatten_ins_ind_labels.sum()

    # dice loss
    loss_ins = []
    for input, target in zip(ins_preds, ins_labels):
        if input.size()[0] == 0:
            continue
        input = torch.sigmoid(input)
        loss_ins.append(dice_loss(input, target))
    loss_ins = torch.cat(loss_ins).mean()  # average over the 5 levels
    loss_ins = loss_ins * self.ins_loss_weight

    # cate
    cate_labels = [
        torch.cat([cate_labels_level_img.flatten()
                   for cate_labels_level_img in cate_labels_level])
        for cate_labels_level in zip(*cate_label_list)
    ]
    flatten_cate_labels = torch.cat(cate_labels)

    cate_preds = [
        cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels)
        for cate_pred in cate_preds
    ]
    flatten_cate_preds = torch.cat(cate_preds)

    loss_cate = self.loss_cate(flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1)
    return dict(
        loss_ins=loss_ins,
        loss_cate=loss_cate)
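The dice_loss called above is not shown in the snippet. A NumPy sketch of the standard dice formulation it follows (per-instance loss, with a small smoothing term; the exact constant in the repo may differ):

```python
import numpy as np

def dice_loss(pred, target, eps=1e-3):
    # pred: sigmoid probabilities, target: binary masks, both (num_obj, h, w)
    pred = pred.reshape(pred.shape[0], -1)
    target = target.reshape(target.shape[0], -1).astype(float)
    a = (pred * target).sum(axis=1)          # overlap
    b = (pred * pred).sum(axis=1) + eps
    c = (target * target).sum(axis=1) + eps
    return 1 - 2 * a / (b + c)               # one loss value per instance
```

A perfect prediction gives a loss near 0; a fully disjoint one gives exactly 1.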
Reposted from blog.csdn.net/qq_36530992/article/details/105302200