GCnet论文详解

传送地址:https://arxiv.org/pdf/1904.11492.pdf

主要思想

如何获取全文信息(global context) 或者叫长距离依赖性。从视觉角度直观理解,比如我们要认出一个人的话可能需要看整个脸部才能认的出来,单给你脸部的一块皮肤,鼻子相对比较难判断。极端一点来说,例如你的朋友离你很远的情况下,需要看身高,衣服来获取整体的信息才能做出判断。
那么对于图像卷积来说,如何获取全文信息呢?这对大物体的识别和分类是非常有帮助的。卷积网络中典型的就是通过卷积层的堆积 (不太懂的可以看这个两个33的卷积为什么能替代55),这种解决方案存在的问题是计算量效率低,更加难优化模型。那么为了缓解这种问题提出了自注意机制(self-attention mechanism),了解NLP的会比较熟悉这个名字。
这里用生活的语言来解释的话可以怎么理解,我们在识别一种物体的话首先是宏观的来看,然后聚焦中某一个点提取信息来识别。比如我们拿女明星来举例子,猛地一看长的都一样,仔细看他的鼻子,眼睛等。就能够分辨。
例子仅供参考,,,,比如像我这种脸盲的

实验结果

  • 自注意力机制
    对应的block主要有三种: SE,NL和GC,NL不是特别熟悉,有大佬了解的话还望留言赐教!
  • 结构图

在这里插入图片描述

  • block实验对比图
    在这里插入图片描述
    自注意力机制的block分为两个模块,一个是获取全文信息,一个是进行信息的转化(transform)。从结构图上来看NL和GC的区别主要在transform模块,看似差别不是很大,但是在实验结果上GC比NL模块高出1.4个点。并且论文中提出通过可视化特征图,NL模块在注意点在不同的位置特征图基本是一样的,可以理解为没有注意到关键点上面??

从实现角度来讲,这三个block不同在哪里呢?transform相对比较好理解,就是在全文信息的基础上转换为0-1之间的权重值。简单直接的方式就是就是在全局求个最大值,平均值就可以在一定程度上代表全文的信息啦,恭喜你,答对了SE block就是怎么做的。
如果全文的信息相互的之间有交互,是不是效果就更好呢?那么GC block就是在该基础上实现的

# 汇集全文的信息 对应的像素点进行匹配,整个图像的像素点全部相加
 context = torch.matmul(input_x, context_mask)

大家调试下面的代码,发现context_mask是0-1之间的值,我们可以理解成它代表这图像中每个像素点的重要性
matmul是点积的操作,大家可以用二维的tensor实验。点积的相乘再相加完成了信息的交互

核心代码

# -*- coding: utf-8 -*-
# @Time    : 2019/5/29 15:30
# @Author  : ljf
from __future__ import absolute_import
import torch
from torch import nn
from mmcv.cnn import constant_init, kaiming_init
import math

def last_zero_init(m):
    if isinstance(m, nn.Sequential):
        # nn.init.constant(m[-1].weight,val=0)
        constant_init(m[-1], val=0)
        m[-1].inited = True
    else:
        constant_init(m, val=0)
        m.inited = True
class ContextBlock2d(nn.Module):

    def __init__(self, inplanes, planes, pool, fusions):
        super(ContextBlock2d, self).__init__()
        assert pool in ['avg', 'att']
        assert all([f in ['channel_add', 'channel_mul'] for f in fusions])
        assert len(fusions) > 0, 'at least one fusion should be used'
        self.inplanes = inplanes
        self.planes = planes
        self.pool = pool
        self.fusions = fusions
        if 'att' in pool:
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if 'channel_add' in fusions:
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1)
            )
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusions:
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1)
            )
        else:
            self.channel_mul_conv = None
        self.reset_parameters()

    def reset_parameters(self):
        if self.pool == 'att':
            kaiming_init(self.conv_mask, mode='fan_in')
            self.conv_mask.inited = True

        if self.channel_add_conv is not None:
            last_zero_init(self.channel_add_conv)
        if self.channel_mul_conv is not None:
            last_zero_init(self.channel_mul_conv)

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        if self.pool == 'att':
            input_x = x
            # [N, C, H * W]
            input_x = input_x.view(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W]
            context_mask = context_mask.view(batch, 1, height * width)
            # [N, 1, H * W]
            context_mask = self.softmax(context_mask)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(3)
            # [N, 1, C, 1]
            # 汇集全文的信息 对应的像素点进行匹配,整个图像的像素点全部相加
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)

        return context

    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)

        if self.channel_mul_conv is not None:
            # [N, C, 1, 1]
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = x * channel_mul_term
        else:
            out = x
        if self.channel_add_conv is not None:
            # [N, C, 1, 1]
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term

        return out


if __name__ == "__main__":
    inputs = torch.randn(1,16,300,300)
    block = ContextBlock2d(16,4,"att",["channel_add"])
    out = block(inputs)
    print(out.size())

猜你喜欢

转载自blog.csdn.net/weixin_42662358/article/details/90676272
今日推荐