Deeplab-V3(ASPP)

问题提出

在这里插入图片描述在这里插入图片描述

改进的ASPP

在这里插入图片描述
引入 global context information：global average pooling → 1×1 conv + BN → bilinearly upsample

  • (a) Atrous Spatial Pyramid Pooling 一个1x1卷积和三个3x3的采样率为rates={6,12,18}的空洞卷积,滤波器数量为256,包含BN层。（注意：下文示例代码中实际使用的空洞率为 {6,8,12}，且未包含 BN 层。）
  • (b) Image Pooling 图像级特征,即将特征做全局平均池化,经过1×1 conv+bn再上采样
    在这里插入图片描述
    在这里插入图片描述

代码实现

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable

class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling (DeepLab-V3 style).

    Five parallel branches — one 1x1 conv, three 3x3 dilated convs, and an
    image-level branch (global average pool -> 1x1 conv -> bilinear
    upsample) — are concatenated and fused by a final 1x1 conv.

    Args:
        dim_in: number of input channels.
        aspp_dim: channels produced by each ASPP branch (default 256).
        hidden_dim: channels of the fused output, exposed as ``self.dim_out``
            (default 256).
        dilations: the three dilation rates for the 3x3 branches. The
            DeepLab-V3 paper uses (6, 12, 18); the default (6, 8, 12)
            preserves this file's original behavior.
        num_convs_before: optional count of plain 3x3 conv+ReLU layers
            applied before the ASPP branches (default 0).
        num_convs_after: optional count of plain 3x3 conv+ReLU layers
            applied after the fusion conv (default 0).
        kernel_size: kernel size of the before/after conv stacks.

    NOTE: unlike the paper, this implementation contains no BatchNorm,
    matching the original code.
    """

    def __init__(self, dim_in, aspp_dim=256, hidden_dim=256,
                 dilations=(6, 8, 12), num_convs_before=0,
                 num_convs_after=0, kernel_size=3):
        super(ASPP, self).__init__()
        self.dim_in = dim_in
        self.NUM_CONVS_BEFORE_ASPP = num_convs_before
        self.NUM_CONVS_AFTER_ASPP = num_convs_after
        d1, d2, d3 = dilations
        pad_size = kernel_size // 2  # "same" padding for the plain convs

        # Optional plain conv stack applied before the ASPP branches.
        before_aspp_list = []
        for _ in range(self.NUM_CONVS_BEFORE_ASPP):
            before_aspp_list.append(
                nn.Conv2d(dim_in, hidden_dim, kernel_size, 1, pad_size))
            before_aspp_list.append(nn.ReLU(inplace=True))
            dim_in = hidden_dim  # subsequent layers see hidden_dim channels
        if before_aspp_list:
            self.conv_before_aspp = nn.Sequential(*before_aspp_list)

        # Branch 1: 1x1 conv.
        self.aspp1 = nn.Sequential(
            nn.Conv2d(dim_in, aspp_dim, 1, 1),
            nn.ReLU(inplace=True),
        )
        # Branches 2-4: 3x3 dilated convs. padding == dilation keeps the
        # spatial size unchanged.
        self.aspp2 = nn.Sequential(
            nn.Conv2d(dim_in, aspp_dim, 3, 1, d1, dilation=d1),
            nn.ReLU(inplace=True),
        )
        self.aspp3 = nn.Sequential(
            nn.Conv2d(dim_in, aspp_dim, 3, 1, d2, dilation=d2),
            nn.ReLU(inplace=True),
        )
        self.aspp4 = nn.Sequential(
            nn.Conv2d(dim_in, aspp_dim, 3, 1, d3, dilation=d3),
            nn.ReLU(inplace=True),
        )
        # Branch 5: image-level feature (global average pool + 1x1 conv);
        # upsampled back to the input resolution in forward().
        self.aspp5 = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(dim_in, aspp_dim, 1, 1),
            nn.ReLU(inplace=True),
        )
        # Fuse the five concatenated branches down to hidden_dim channels.
        self.feat = nn.Sequential(
            nn.Conv2d(aspp_dim * 5, hidden_dim, 1, 1),
            nn.ReLU(inplace=True),
        )

        # Optional plain conv stack applied after the fusion conv.
        after_aspp_list = []
        for _ in range(self.NUM_CONVS_AFTER_ASPP):
            after_aspp_list.append(
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size, 1, pad_size))
            after_aspp_list.append(nn.ReLU(inplace=True))
        if after_aspp_list:
            self.conv_after_aspp = nn.Sequential(*after_aspp_list)

        self.dim_out = hidden_dim

    def forward(self, x):
        """Run ASPP on a (N, dim_in, H, W) tensor; returns (N, dim_out, H, W)."""
        if self.NUM_CONVS_BEFORE_ASPP > 0:
            x = self.conv_before_aspp(x)

        _, _, h, w = x.shape
        # The image-level branch collapses to 1x1; bilinearly upsample it
        # back to (h, w) so it can be concatenated with the conv branches.
        global_feature = F.interpolate(
            self.aspp5(x), size=(h, w), mode='bilinear', align_corners=True)

        x = torch.cat(
            [self.aspp1(x), self.aspp2(x), self.aspp3(x), self.aspp4(x),
             global_feature],
            dim=1)
        x = self.feat(x)

        # BUG FIX: the original built conv_after_aspp but never applied it.
        if self.NUM_CONVS_AFTER_ASPP > 0:
            x = self.conv_after_aspp(x)
        return x


if __name__ == '__main__':
    # Smoke test: push a dummy batch through ASPP and print the output shape.
    # Fall back to CPU when CUDA is unavailable (the original unconditionally
    # called .cuda() and crashed on CPU-only machines).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    aspp = ASPP(256).to(device)
    x = torch.ones([1, 256, 14, 14], device=device)
    y = aspp(x)
    print(y.shape)

猜你喜欢

转载自blog.csdn.net/qq_40263477/article/details/106593694