SA-NET-轻量级注意力 | SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS

论文地址:https://arxiv.org/pdf/2102.00240.pdf
Github地址:https://github.com/wofmanaf/SA-Net/blob/main/models/sa_resnet.py

在这里插入图片描述

Abstract:

注意力机制使神经网络能够准确地专注于输入的所有相关元素,它已成为改善深度神经网络性能的重要组成部分。在计算机视觉研究中广泛使用的注意力机制主要有两种,即空间注意力和通道注意力,它们分别用于捕获像素级成对关系和通道依赖性。尽管将它们融合在一起可能会比其单独的实现获得更好的性能,但这将不可避免地增加计算开销。在本文中,我们提出了一个有效的Shuffle Attention(SA)模块来解决此问题,该模块采用Shuffle单元有效地结合了两种类型的注意机制。具体而言,SA首先将通道维分组为多个子特征,然后再并行处理它们。然后,对于每个子特征,SA利用shuffle单元在空间和通道维度上描绘特征依赖性。之后,将所有子特征汇总在一起,并使用“通道混洗”运算符来启用不同子特征之间的信息通信。所提出的SA模块既有效又高效,例如,针对骨干网ResNet50的SA的参数和计算分别为300 vs. 25.56M和2.76e-3 GFLOPs vs. 4.12 GFLOPs,并且在TOP-1准确度方面提升在1.34%以上。在常用基准上的大量实验结果,包括用于分类的ImageNet-1k,用于目标检测的MS COCO和实例分割任务上,证明了所提出的SA通过获得较高的精度而具有较低的模型复杂性,从而明显优于当前的SOTA方法。

Introduction:

现有的注意力机制主要包含通道注意力机制(channel attention)和空间注意力机制(spatial attention),以及两者的结合CBAM。本文的想法是提出一个轻量级且高效融合两种注意力机制的综合模块。
在这里插入图片描述

主要动机来源于:

  1. ShuffleNet中的channel shuffle可以高效的保证通道信息交互
  2. SGE中将特征按照通道维度分组

因此,本文结合上述两个动机,提出了一种更轻量级且更高效的shuffle注意力(SA)模块,该模块将通道维度划分为子特征。对于每个子特征,SA均采用shuffle单元同时构建通道注意力和空间注意力。对于每个注意力模块,本文在所有位置上都设计了一个注意力mask,以抑制可能出现的噪声并突出显示正确的语义特征区域。在ImageNet-1k上的实验结果(如图1所示)表明,所提出的简单但有效的模块包含较少的参数,其精度要高于当前的最新方法。

本文的主要贡献概述如下:
1)我们针对深层CNN引入了一个轻量级但有效的注意力模块SA,该模块将通道维度分为多个子特征,然后利用shuffle单元整合互补通道,并每个子特征的空间关注模块。
2)在ImageNet-1k和MS COCO上进行的广泛实验结果表明,所提出的SA与最先进的注意力方法相比,具有较低的模型复杂度,同时具有出色的性能。

Shuffle Attention:

在这里插入图片描述
该结构很简单,流程如下:

1.对输入特征X按照通道维度分组X1…Xg,每个子特征的通道数位c/g

2.对每个子特征继续split成两份,分别提取通道注意力和空间注意力

3.后处理融合

直接看代码,使用parameter代替卷积比较新颖:

class sa_layer(nn.Module):
    """Constructs a Channel Spatial Group module.
    Args:
        k_size: Adaptive selection of kernel size
    """

    def __init__(self, channel, groups=64):
        super(sa_layer, self).__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.cweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))

        self.sigmoid = nn.Sigmoid()
        self.gn = nn.GroupNorm(channel // (2 * groups), channel // (2 * groups))

    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape

        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)

        # flatten
        x = x.reshape(b, -1, h, w)

        return x

    def forward(self, x):
        b, c, h, w = x.shape

        x = x.reshape(b * self.groups, -1, h, w)
        x_0, x_1 = x.chunk(2, dim=1)

        # channel attention
        xn = self.avg_pool(x_0)
        xn = self.cweight * xn + self.cbias
        xn = x_0 * self.sigmoid(xn)

        # spatial attention
        xs = self.gn(x_1)
        xs = self.sweight * xs + self.sbias
        xs = x_1 * self.sigmoid(xs)

        # concatenate along channel axis
        out = torch.cat([xn, xs], dim=1)
        out = out.reshape(b, -1, h, w)

        out = self.channel_shuffle(out, 2)
        return out

Experiments:

1.Imagenet 1K:
在这里插入图片描述

2.object detection:

在这里插入图片描述
3.Ablation Study:
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_42096202/article/details/113774810