GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond

the Best Paper Award at ICCV 2019 Neural Architects Workshop

论文链接:https://arxiv.org/pdf/1904.11492v1.pdf

主要思想:作者发现在non local(https://blog.csdn.net/breeze_blows/article/details/104715120)中,对于不同的query positions,他们的attention map几乎是一样的,即产生的global contexts是差不多的,也就是说global contexts是query positions独立的,利用这个性质还有借助SENET(https://blog.csdn.net/breeze_blows/article/details/104834567)的思想,作者简化了non local提出GCNET,而且将SENET,GCNET,简化之后的non local归为一个结构模式,在object detection上面的实现表示,GCNET需要的参数更少,而且有着与non locol相似的精度。

 就像上图中看到的不同的query position的attention基本一样,按照文中的说法就是for different query positions, their attention maps are almost the same. the global context after training is actually independent of query position,there is no need to compute query-specific global context for each query position,而且文中还对不同query position的attention map计算了距离,发现确实很接近。

 non local的可以用公式表述为

 首先是根据query的attention map是位置独立的,将non local简化为

为了进一步简化,将Wv提到外面

The FLOPs of the 1x1 conv Wv is reduced from O(HWC2 ) to O(C2 ).这种形式就对应上图4中的b,进一步将上述的图4的几种结构归到一个框架下

 一共由三部分组成context modeling,transform,fusion。具体描述感觉文中写的很清晰

 GCNET最后可以公式表示为

至于为什么要在Wv1,Wv2也就是两个conv1*1中间加LN,其实normalize还有很多方式比如BN,GN,IN,为什么是LN可能是尝试出来的吧,文中的解释是这样更好优化一点,即网络更好训练,As the two-layer bottleneck transform increases the difficulty of optimization, we add layer normalization inside the bottleneck transform (before ReLU) to ease optimization, as well as to act as a regularizer that can benefit generalization.

最后作者对比non local做了对比试验,baseline是mask rcnn,+1表示在resnet的最后一个block进行添加,默认是加在block最后一个conv1*1之前,+all表示对resnet所有block都添加,可以看出GC的参数量更少,但是可以取得和non local差不多的精度。

 试验表明GC用在resnet 的C3,C4,C5效果最好,对图4d transform的r进行了实验,r越大精度越好,参数量越多,最后选取4

 最后就是关于pool与fusion的对比实验,这里的att就对应着图4d中的tranform操作,可以发现att+scale仅仅比avg+scale效果好了一点点,这证明,fusion的影响远远大于pooling,就像文中说的how global context is aggregated to query positions (choice of fusion module) is more important than how features from all positions are grouped together (choice in context modeling module)

 GCNET源码:https://github.com/xvjiarui/GCNet/blob/master/mmdet/ops/gcb/context_block.py

感觉网络层的初始化方式有点奇怪,不知道怎么来的

import torch
from mmcv.cnn import constant_init, kaiming_init
from torch import nn

def constant_init(module, val, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)

def last_zero_init(m):
    if isinstance(m, nn.Sequential):
        constant_init(m[-1], val=0)
    else:
        constant_init(m, val=0)


class ContextBlock(nn.Module):

    def __init__(self,
                 inplanes,
                 ratio,
                 pooling_type='att',
                 fusion_types=('channel_add', )):
        super(ContextBlock, self).__init__()
        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        valid_fusion_types = ['channel_add', 'channel_mul']
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'
        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        if pooling_type == 'att': #att
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1) #avg
        if 'channel_add' in fusion_types: #add
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusion_types: #scale
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_mul_conv = None
        self.reset_parameters()

    def reset_parameters(self):
        if self.pooling_type == 'att':
            kaiming_init(self.conv_mask, mode='fan_in')
            self.conv_mask.inited = True

        if self.channel_add_conv is not None:
            last_zero_init(self.channel_add_conv)
        if self.channel_mul_conv is not None:
            last_zero_init(self.channel_mul_conv)

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            input_x = x
            # [N, C, H * W]
            input_x = input_x.view(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W]
            context_mask = context_mask.view(batch, 1, height * width)
            # [N, 1, H * W]
            context_mask = self.softmax(context_mask)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N, 1, C, 1]
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)

        return context

    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x) #pool

        out = x
        if self.channel_mul_conv is not None: #scale
            # [N, C, 1, 1]
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        if self.channel_add_conv is not None: #add
            # [N, C, 1, 1]
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term

        return out

其他:

  • GC block的transform中两个conv中间加了normalization,但是用了LN而不是BN,WN等,可能是实验出来的吧
  • 不同的query的attention map是差不多的,可以看看知乎大佬们的讨论:https://zhuanlan.zhihu.com/p/64988633,其中https://www.zhihu.com/people/li-xia-zhi-guang说的我觉得蛮有道理:我的理解是任务不一样吧。语义分割需要对每个像素都输出,所以要“雨露均沾”。而分类只需关注最重要的概念就OK了;而检测的话,正例数量远远小于反例。所以只focus正例也没毛病。

猜你喜欢

转载自blog.csdn.net/breeze_blows/article/details/104878928