问题提出
改进的ASPP
引入global context information: global average pooling -> 1×1 conv + bn -> bilinearly upsample
- (a) Atrous Spatial Pyramid Pooling 一个1x1卷积和三个3x3的采样率为rates={6,12,18}的空洞卷积,滤波器数量为256,包含BN层。(注:下方代码实现默认采用rates={6,8,12},且未包含BN层,与论文设置不同。)
- (b) Image Pooling 图像级特征,即将特征做全局平均池化,经过1×1 conv+bn再上采样
代码实现
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling (DeepLabv3-style).

    Four parallel branches (one 1x1 conv and three 3x3 atrous convs with
    increasing dilation) are concatenated with an image-level feature
    (global average pooling -> 1x1 conv -> bilinear upsample) and fused
    back to ``hidden_dim`` channels by a final 1x1 conv.

    NOTE(review): unlike the DeepLabv3 paper, this implementation has no
    BatchNorm layers and defaults to dilation rates (6, 8, 12) rather
    than (6, 12, 18).
    """

    def __init__(self, dim_in, hidden_dim=256, aspp_dim=256,
                 dilations=(6, 8, 12), num_convs_before=0,
                 num_convs_after=0):
        """
        Args:
            dim_in (int): number of input channels.
            hidden_dim (int): channels of the fused output (and of the
                optional pre/post conv stacks).
            aspp_dim (int): channels produced by each ASPP branch.
            dilations (tuple[int, int, int]): dilation rates of the three
                3x3 atrous branches.
            num_convs_before (int): number of 3x3 conv+ReLU layers applied
                before ASPP (0 = disabled, the original hard-coded value).
            num_convs_after (int): number of 3x3 conv+ReLU layers applied
                after ASPP (0 = disabled, the original hard-coded value).
        """
        super(ASPP, self).__init__()
        self.dim_in = dim_in
        self.NUM_CONVS_BEFORE_ASPP = num_convs_before
        self.NUM_CONVS_AFTER_ASPP = num_convs_after
        kernel_size = 3
        pad_size = kernel_size // 2
        d1, d2, d3 = dilations

        # Optional conv stack before ASPP (disabled by default).
        before_aspp_list = []
        for _ in range(self.NUM_CONVS_BEFORE_ASPP):
            before_aspp_list.append(
                nn.Conv2d(dim_in, hidden_dim, kernel_size, 1, pad_size))
            before_aspp_list.append(nn.ReLU(inplace=True))
            dim_in = hidden_dim  # subsequent layers consume hidden_dim
        if self.NUM_CONVS_BEFORE_ASPP > 0:
            self.conv_before_aspp = nn.Sequential(*before_aspp_list)

        # Branch 1: plain 1x1 conv.
        self.aspp1 = nn.Sequential(
            nn.Conv2d(dim_in, aspp_dim, 1, 1),
            nn.ReLU(inplace=True),
        )
        # Branches 2-4: 3x3 atrous convs; padding == dilation keeps H, W.
        self.aspp2 = nn.Sequential(
            nn.Conv2d(dim_in, aspp_dim, 3, 1, d1, dilation=d1),
            nn.ReLU(inplace=True),
        )
        self.aspp3 = nn.Sequential(
            nn.Conv2d(dim_in, aspp_dim, 3, 1, d2, dilation=d2),
            nn.ReLU(inplace=True),
        )
        self.aspp4 = nn.Sequential(
            nn.Conv2d(dim_in, aspp_dim, 3, 1, d3, dilation=d3),
            nn.ReLU(inplace=True),
        )
        # Branch 5: image-level feature (global avg pool -> 1x1 conv);
        # upsampled back to the input resolution in forward().
        self.aspp5 = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(dim_in, aspp_dim, 1, 1),
            nn.ReLU(inplace=True),
        )
        # Fuse the five concatenated branches down to hidden_dim channels.
        self.feat = nn.Sequential(
            nn.Conv2d(aspp_dim * 5, hidden_dim, 1, 1),
            nn.ReLU(inplace=True),
        )

        # Optional conv stack after ASPP (disabled by default).
        after_aspp_list = []
        for _ in range(self.NUM_CONVS_AFTER_ASPP):
            after_aspp_list.append(
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size, 1, pad_size))
            after_aspp_list.append(nn.ReLU(inplace=True))
        if self.NUM_CONVS_AFTER_ASPP > 0:
            self.conv_after_aspp = nn.Sequential(*after_aspp_list)

        self.dim_out = hidden_dim

    def forward(self, x):
        """Apply ASPP.

        Args:
            x: tensor of shape (N, dim_in, H, W).

        Returns:
            Tensor of shape (N, dim_out, H, W) — spatial size preserved.
        """
        if self.NUM_CONVS_BEFORE_ASPP > 0:
            x = self.conv_before_aspp(x)
        n, c, h, w = x.shape
        # Image-level feature: 1x1 map bilinearly upsampled back to (h, w).
        global_feature = self.aspp5(x)
        global_feature = F.interpolate(
            global_feature, (h, w), None, 'bilinear', True)
        feature_list = [self.aspp1(x), self.aspp2(x), self.aspp3(x),
                        self.aspp4(x), global_feature]
        x = torch.cat(feature_list, 1)
        x = self.feat(x)
        # BUG FIX: the original built conv_after_aspp but never applied it.
        # No-op at the default NUM_CONVS_AFTER_ASPP == 0.
        if self.NUM_CONVS_AFTER_ASPP > 0:
            x = self.conv_after_aspp(x)
        return x
if __name__ == '__main__':
    # Smoke test: verify output shape matches input spatial size.
    # BUG FIX: original called .cuda() unconditionally and crashed on
    # CPU-only machines; fall back to CPU when CUDA is unavailable.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    aspp = ASPP(256).to(device)
    x = torch.ones([1, 256, 14, 14], device=device)
    y = aspp(x)
    print(y.shape)