Preface
Paper link: LRNNET - A Lightweight Real-Time Semantic Segmentation Network
The SS-nbt module in LEDNet
LEDNet's SS-nbt (split-shuffle non-bottleneck) unit splits the input channels into two halves, pushes each half through factorized 3x1/1x3 convolutions (with a dilated second stage), then merges the branches, adds the residual, and shuffles the channels so the two halves exchange information.
import torch
import torch.nn as nn
import torch.nn.functional as F


def Split(x):
    # Split the input tensor into two halves along the channel dimension.
    c = int(x.size()[1])
    c1 = round(c * 0.5)
    x1 = x[:, :c1, :, :].contiguous()
    x2 = x[:, c1:, :, :].contiguous()
    return x1, x2


def Merge(x1, x2):
    # Concatenate the two branches back along the channel dimension.
    return torch.cat((x1, x2), 1)


def Channel_shuffle(x, groups):
    # Interleave channels across groups (as in ShuffleNet) so that
    # information can flow between the split branches.
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups
    # reshape: (N, C, H, W) -> (N, groups, C // groups, H, W)
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    # flatten back to (N, C, H, W)
    x = x.view(batchsize, -1, height, width)
    return x
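# For example, with groups=2 and channels [a0, a1, b0, b1], the shuffle
# reorders them to [a0, b0, a1, b1], mixing the two branch halves.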
class SS_nbt_module(nn.Module):
    """LEDNet's split-shuffle non-bottleneck (SS-nbt) module."""

    def __init__(self, chann, dropprob, dilated):
        super().__init__()
        oup_inc = chann // 2
        # Left branch: factorized 3x1/1x3 convolutions, then a dilated pair.
        self.conv3x1_1_l = nn.Conv2d(oup_inc, oup_inc, (3, 1), stride=1, padding=(1, 0), bias=True)
        self.conv1x3_1_l = nn.Conv2d(oup_inc, oup_inc, (1, 3), stride=1, padding=(0, 1), bias=True)
        self.bn1_l = nn.BatchNorm2d(oup_inc, eps=1e-03)
        self.conv3x1_2_l = nn.Conv2d(oup_inc, oup_inc, (3, 1), stride=1, padding=(dilated, 0), bias=True, dilation=(dilated, 1))
        self.conv1x3_2_l = nn.Conv2d(oup_inc, oup_inc, (1, 3), stride=1, padding=(0, dilated), bias=True, dilation=(1, dilated))
        self.bn2_l = nn.BatchNorm2d(oup_inc, eps=1e-03)
        # Right branch: same structure with the 1x3/3x1 order swapped.
        self.conv3x1_1_r = nn.Conv2d(oup_inc, oup_inc, (3, 1), stride=1, padding=(1, 0), bias=True)
        self.conv1x3_1_r = nn.Conv2d(oup_inc, oup_inc, (1, 3), stride=1, padding=(0, 1), bias=True)
        self.bn1_r = nn.BatchNorm2d(oup_inc, eps=1e-03)
        self.conv3x1_2_r = nn.Conv2d(oup_inc, oup_inc, (3, 1), stride=1, padding=(dilated, 0), bias=True, dilation=(dilated, 1))
        self.conv1x3_2_r = nn.Conv2d(oup_inc, oup_inc, (1, 3), stride=1, padding=(0, dilated), bias=True, dilation=(1, dilated))
        self.bn2_r = nn.BatchNorm2d(oup_inc, eps=1e-03)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout2d(dropprob)

    def forward(self, x):
        residual = x
        x1, x2 = Split(x)
        # Left branch, stage 1: 3x1 -> 1x3.
        output1 = self.conv3x1_1_l(x1)
        output1 = self.relu(output1)
        output1 = self.conv1x3_1_l(output1)
        output1 = self.bn1_l(output1)
        output1_mid = self.relu(output1)
        # Right branch, stage 1: 1x3 -> 3x1.
        output2 = self.conv1x3_1_r(x2)
        output2 = self.relu(output2)
        output2 = self.conv3x1_1_r(output2)
        output2 = self.bn1_r(output2)
        output2_mid = self.relu(output2)
        # Stage 2: dilated factorized convolutions enlarge the receptive field.
        output1 = self.conv3x1_2_l(output1_mid)
        output1 = self.relu(output1)
        output1 = self.conv1x3_2_l(output1)
        output1 = self.bn2_l(output1)
        output2 = self.conv1x3_2_r(output2_mid)
        output2 = self.relu(output2)
        output2 = self.conv3x1_2_r(output2)
        output2 = self.bn2_r(output2)
        if self.dropout.p != 0:
            output1 = self.dropout(output1)
            output2 = self.dropout(output2)
        # Merge the branches, add the residual, then shuffle channels.
        out = Merge(output1, output2)
        out = F.relu(residual + out)
        out = Channel_shuffle(out, 2)
        return out


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    ss_nbt = SS_nbt_module(256, 0.2, 6).to(device)
    x = torch.randn(1, 256, 14, 14, device=device)
    y = ss_nbt(x)
    print(y.shape)  # torch.Size([1, 256, 14, 14])
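As a quick sanity check, the sketch below (a minimal example assuming the SS_nbt_module definition above; the count_params helper is my own) counts the block's trainable parameters, which come out to roughly 0.4M at chann=256 thanks to the factorized 1D convolutions:

import torch.nn as nn

def count_params(m: nn.Module) -> int:
    # Total number of trainable parameter elements in the module.
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

print(count_params(SS_nbt_module(256, 0.2, 6)))  # roughly 0.4M parameters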
The FCB module in LRNNet
LRNNet's FCB (Factorized Convolution Block) keeps the split-and-shuffle design but uses the factorized 1D convolutions only for short-range features; long-range context is handled by a single depthwise separable dilated convolution over the merged channels.
import torch
import torch.nn as nn
import torch.nn.functional as F


# Same helpers as in the SS-nbt listing, repeated so this snippet runs on its own.
def Split(x):
    c = int(x.size()[1])
    c1 = round(c * 0.5)
    x1 = x[:, :c1, :, :].contiguous()
    x2 = x[:, c1:, :, :].contiguous()
    return x1, x2


def Merge(x1, x2):
    return torch.cat((x1, x2), 1)


def Channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    x = x.view(batchsize, -1, height, width)
    return x


class FCB_module(nn.Module):
    """LRNNet's Factorized Convolution Block (FCB)."""

    def __init__(self, chann, dropprob, dilated):
        super().__init__()
        oup_inc = chann // 2
        # Left branch: factorized 3x1/1x3 convolutions for short-range features.
        self.conv3x1_1_l = nn.Conv2d(oup_inc, oup_inc, (3, 1), stride=1, padding=(1, 0), bias=True)
        self.conv1x3_1_l = nn.Conv2d(oup_inc, oup_inc, (1, 3), stride=1, padding=(0, 1), bias=True)
        self.bn1_l = nn.BatchNorm2d(oup_inc, eps=1e-03)
        # Right branch: same structure with the 1x3/3x1 order swapped.
        self.conv3x1_1_r = nn.Conv2d(oup_inc, oup_inc, (3, 1), stride=1, padding=(1, 0), bias=True)
        self.conv1x3_1_r = nn.Conv2d(oup_inc, oup_inc, (1, 3), stride=1, padding=(0, 1), bias=True)
        self.bn1_r = nn.BatchNorm2d(oup_inc, eps=1e-03)
        # Depthwise separable dilated convolution (the 'ds' part) for long-range
        # features: a depthwise dilated 3x3 (groups=chann) followed by a pointwise 1x1.
        self.conv3x3 = nn.Conv2d(chann, chann, (3, 3), stride=1, padding=(dilated, dilated), bias=True, dilation=(dilated, dilated), groups=chann)
        self.conv1x1 = nn.Conv2d(chann, chann, (1, 1), stride=1)
        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout2d(dropprob)

    def forward(self, x):
        residual = x
        x1, x2 = Split(x)
        # Left branch: 3x1 -> 1x3.
        output1 = self.conv3x1_1_l(x1)
        output1 = self.relu(output1)
        output1 = self.conv1x3_1_l(output1)
        output1 = self.bn1_l(output1)
        output1_mid = self.relu(output1)
        # Right branch: 1x3 -> 3x1.
        output2 = self.conv1x3_1_r(x2)
        output2 = self.relu(output2)
        output2 = self.conv3x1_1_r(output2)
        output2 = self.bn1_r(output2)
        output2_mid = self.relu(output2)
        if self.dropout.p != 0:
            output1_mid = self.dropout(output1_mid)
            output2_mid = self.dropout(output2_mid)
        # Merge the branches, then apply the depthwise separable dilated convolution.
        output = Merge(output1_mid, output2_mid)
        output = F.relu(output)
        output = self.conv3x3(output)
        output = self.relu(output)
        output = self.conv1x1(output)
        output = self.bn2(output)
        # Residual connection followed by a channel shuffle.
        output = F.relu(residual + output)
        output = Channel_shuffle(output, 2)
        return output


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    fcb = FCB_module(256, 0.2, 6).to(device)
    x = torch.randn(1, 256, 14, 14, device=device)
    y = fcb(x)
    print(y.shape)  # torch.Size([1, 256, 14, 14])
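For a side-by-side view, this minimal sketch reuses the hypothetical count_params helper from the earlier snippet and assumes both module definitions live in one script; with the depthwise grouping in conv3x3, FCB should report noticeably fewer parameters than SS-nbt at the same width:

print('SS-nbt:', count_params(SS_nbt_module(256, 0.2, 6)))
print('FCB:   ', count_params(FCB_module(256, 0.2, 6)))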