Paper link: CBAM: Convolutional Block Attention Module
Purpose of CBAM:
To add an attention mechanism to a network.
Structure of CBAM:
① Channel attention module: the input feature map goes through global max pooling and global average pooling separately; both pooled results are fed through a weight-shared MLP, the two outputs are summed element-wise, and a sigmoid activation yields the channel attention weights $M_c$;
② Spatial attention module: the input feature map is max-pooled and average-pooled along the channel dimension, producing a (2, H, W) feature map; a 7x7 convolution then outputs a single-channel map, and a sigmoid activation yields the spatial attention weights $M_s$;
③ Combining the two in series: the authors stack the modules sequentially, with the channel attention module first and the spatial attention module second.
Through experiments, the authors found that the sequential arrangement outperforms the parallel one, and that applying channel attention before spatial attention works better than the reverse order (the paper's formulas are collected right below).
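For reference, these are the formulas from the paper that the code comments below cite:

$$
F' = M_c(F) \otimes F,\qquad F'' = M_s(F') \otimes F' \tag{1}
$$

$$
M_c(F) = \sigma\big(\mathrm{MLP}(\mathrm{AvgPool}(F)) + \mathrm{MLP}(\mathrm{MaxPool}(F))\big) = \sigma\big(W_1(W_0(F^c_{avg})) + W_1(W_0(F^c_{max}))\big) \tag{2}
$$

$$
M_s(F) = \sigma\big(f^{7\times 7}([\mathrm{AvgPool}(F);\, \mathrm{MaxPool}(F)])\big) = \sigma\big(f^{7\times 7}([F^s_{avg};\, F^s_{max}])\big) \tag{3}
$$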
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):  # Channel attention module
    def __init__(self, channels, ratio=16):  # r: reduction ratio = 16
        super(ChannelAttention, self).__init__()
        hidden_channels = channels // ratio
        self.avgpool = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.maxpool = nn.AdaptiveMaxPool2d(1)  # global max pooling
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, hidden_channels, 1, 1, 0, bias=False),  # 1x1 conv in place of a fully connected layer; no bias term, following the paper's formula
            nn.ReLU(inplace=True),  # ReLU
            nn.Conv2d(hidden_channels, channels, 1, 1, 0, bias=False)   # 1x1 conv in place of a fully connected layer; no bias term, following the paper's formula
        )
        self.sigmoid = nn.Sigmoid()  # sigmoid

    def forward(self, x):
        x_avg = self.avgpool(x)
        x_max = self.maxpool(x)
        return self.sigmoid(
            self.mlp(x_avg) + self.mlp(x_max)
        )  # Mc(F) = σ(MLP(AvgPool(F)) + MLP(MaxPool(F))) = σ(W1(W0(F_avg^c)) + W1(W0(F_max^c))), Eq. (2) in the paper


class SpatialAttention(nn.Module):  # Spatial attention module
    def __init__(self):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, 7, 1, 3, bias=False)  # 7x7 conv
        self.sigmoid = nn.Sigmoid()  # sigmoid

    def forward(self, x):
        x_avg = torch.mean(x, dim=1, keepdim=True)    # average pooling along the channel dimension, (B,C,H,W) -> (B,1,H,W)
        x_max = torch.max(x, dim=1, keepdim=True)[0]  # max pooling along the channel dimension, (B,C,H,W) -> (B,1,H,W)
        return self.sigmoid(
            self.conv(torch.cat([x_avg, x_max], dim=1))
        )  # Ms(F) = σ(f7x7([AvgPool(F); MaxPool(F)])) = σ(f7x7([F_avg^s; F_max^s])), Eq. (3) in the paper


class CBAM(nn.Module):  # Convolutional Block Attention Module
    def __init__(self, channels, ratio=16):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(channels, ratio)  # Channel attention module
        self.spatial_attention = SpatialAttention()                 # Spatial attention module

    def forward(self, x):
        f1 = self.channel_attention(x) * x    # F' = Mc(F) ⊗ F, Eq. (1) in the paper
        f2 = self.spatial_attention(f1) * f1  # F'' = Ms(F') ⊗ F', Eq. (1) in the paper
        return f2
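A minimal usage sketch, not from the paper: the block layout, channel count (64), and input shape below are arbitrary assumptions, just to show that CBAM drops into a convolutional stage and preserves the tensor shape.

# Usage sketch (assumed example): attach CBAM after a convolutional stage,
# as one would inside a ResNet-style block.
block = nn.Sequential(
    nn.Conv2d(3, 64, 3, 1, 1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    CBAM(64, ratio=16),  # refine the 64-channel feature map with CBAM
)

x = torch.randn(2, 3, 32, 32)  # dummy input: batch of 2 RGB 32x32 images
y = block(x)
print(y.shape)  # torch.Size([2, 64, 32, 32]); CBAM keeps channel and spatial dimensions unchanged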