一、论文
《Single Image Reflection Removal Exploiting Misaligned Training Data and Network Enhancements》
从透过玻璃窗拍摄的单张图像中去除不需要的反射,对视觉计算系统具有重要的实际意义。尽管最先进的方法在某些场景下可以取得不错的结果,但在处理更一般的真实场景时,性能会大幅下降。这些失败源于单图像反射去除的内在困难——问题本身的病态性(ill-posedness),以及在基于学习的神经网络流程中,用于消除这种歧义的密集标注训练数据不足。在本文中,我们通过有针对性的网络增强和对未对齐数据的新颖利用来解决这些问题。对于前者,我们在基线网络结构中嵌入上下文编码模块,利用高层上下文线索来降低强反射区域内的不确定性。对于后者,我们引入了一种对齐不变(alignment-invariant)损失函数,使易于采集的未对齐真实世界数据也能用于训练。实验结果共同表明,我们的方法在对齐数据上优于现有最先进方法,并且在额外使用未对齐数据时还能获得显著提升。
二、网络结构
本文重点关注其中的 Pyramid Pooling(金字塔池化)和 Residual Block(残差块)两个结构。
三、代码
代码下载:https://github.com/Vandermode/ERRNet
# Define network components here
import torch
from torch import nn
import torch.nn.functional as F
class PyramidPooling(nn.Module):
    """Pyramid pooling context module.

    Each stage average-pools the input with a square window of side ``scale``
    (stride equal to the window), projects the result to ``ct_channels``
    channels with a bias-free 1x1 conv and a LeakyReLU(0.2), and resizes it
    back to the input's spatial size with nearest-neighbour interpolation.
    All stage outputs are concatenated with the input features and fused by
    a 1x1 bottleneck conv followed by LeakyReLU(0.2).

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the bottleneck.
        scales: Pooling window sizes, one stage per entry.
        ct_channels: Number of channels produced by each pooling stage.

    Note:
        ``nn.AvgPool2d`` is used without padding, so the input height and
        width must be at least ``max(scales)``. Smaller inputs make the
        pooling stage fail.
    """

    def __init__(self, in_channels, out_channels, scales=(4, 8, 16, 32), ct_channels=1):
        super().__init__()
        self.stages = nn.ModuleList(
            [self._make_stage(in_channels, scale, ct_channels) for scale in scales]
        )
        # The bottleneck sees the original features plus one ct_channels map per stage.
        self.bottleneck = nn.Conv2d(in_channels + len(scales) * ct_channels, out_channels, kernel_size=1, stride=1)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def _make_stage(self, in_channels, scale, ct_channels):
        """Build one pooling branch: avg-pool -> 1x1 conv (no bias) -> LeakyReLU(0.2)."""
        prior = nn.AvgPool2d(kernel_size=(scale, scale))
        conv = nn.Conv2d(in_channels, ct_channels, kernel_size=1, bias=False)
        relu = nn.LeakyReLU(0.2, inplace=True)
        return nn.Sequential(prior, conv, relu)

    def forward(self, feats):
        """Fuse multi-scale pooled context with ``feats``; output keeps feats' H and W."""
        h, w = feats.size(2), feats.size(3)
        # Each pooled map is coarser than feats; resize it back so it can be concatenated.
        pooled = [F.interpolate(input=stage(feats), size=(h, w), mode='nearest') for stage in self.stages]
        priors = torch.cat(pooled + [feats], dim=1)
        return self.relu(self.bottleneck(priors))
class SELayer(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Global-average-pools the input to one value per channel and passes that
    vector through a two-layer bottleneck MLP (Linear -> ReLU -> Linear ->
    Sigmoid). Each input channel is then rescaled by the resulting weight.

    Args:
        channel: Number of input (and output) channels.
        reduction: Bottleneck reduction ratio. The hidden width is
            ``channel // reduction``, clamped to at least 1.
    """

    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        # Clamp so that channel < reduction does not produce a zero-width hidden
        # layer; that would cut the gate off from the input entirely.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` with every channel multiplied by its learned gate in (0, 1)."""
        b, c, _, _ = x.size()
        # (b, c, h, w) -> (b, c): one descriptor per channel.
        y = self.avg_pool(x).view(b, c)
        # Reshape the gates so they broadcast over the spatial dims.
        y = self.fc(y).view(b, c, 1, 1)
        return x * y
class DRNet(torch.nn.Module):
    """Convolutional image-to-image network with a residual body.

    Pipeline: conv1 -> conv2 -> conv3 (stride 2) -> ``n_resblocks`` residual
    blocks -> deconv1 (transposed conv, stride 2) -> deconv2 ->
    [optional PyramidPooling] -> deconv3 (1x1 projection to ``out_channels``).

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        n_feats: Width of every intermediate feature map.
        n_resblocks: Number of residual blocks in the body.
        norm: Normalization class used by the inner conv layers (conv1 and
            deconv3 are built without normalization).
        se_reduction: If not None, each residual block gets an SE layer with
            this reduction ratio.
        res_scale: Residual-branch scale passed to every residual block.
        bottom_kernel_size: Kernel size of the first convolution.
        pyramid: If True, insert a PyramidPooling module with scales
            (4, 8, 16, 32) and ``n_feats // 4`` channels per stage before the
            final projection.

    Note:
        conv3 (k=3, stride 2, padding 1) maps a size s to ceil(s / 2), and
        deconv1 (k=4, stride 2, padding 1) doubles it. For odd input heights
        or widths, the output is therefore one pixel larger along that
        dimension. The final projection also ends in the ReLU, so outputs
        are non-negative.
    """

    def __init__(self, in_channels, out_channels, n_feats, n_resblocks, norm=nn.BatchNorm2d,
                 se_reduction=None, res_scale=1, bottom_kernel_size=3, pyramid=False):
        super(DRNet, self).__init__()
        conv = nn.Conv2d
        deconv = nn.ConvTranspose2d
        # One ReLU instance is reused by every layer; it holds no parameters.
        act = nn.ReLU(True)

        self.pyramid_module = None

        # Initial convolution layers; conv3 downsamples by 2.
        self.conv1 = ConvLayer(conv, in_channels, n_feats, kernel_size=bottom_kernel_size, stride=1, norm=None, act=act)
        self.conv2 = ConvLayer(conv, n_feats, n_feats, kernel_size=3, stride=1, norm=norm, act=act)
        self.conv3 = ConvLayer(conv, n_feats, n_feats, kernel_size=3, stride=2, norm=norm, act=act)

        # Residual body; every block uses dilation 1.
        self.res_module = nn.Sequential(*[
            ResidualBlock(n_feats, dilation=1, norm=norm, act=act,
                          se_reduction=se_reduction, res_scale=res_scale)
            for _ in range(n_resblocks)
        ])

        # Upsampling layers. These are identical with or without the pyramid;
        # only the optional context module differs. Construction order
        # (deconv2, pyramid, deconv3) is kept so init and state_dict order match.
        self.deconv1 = ConvLayer(deconv, n_feats, n_feats, kernel_size=4, stride=2, padding=1, norm=norm, act=act)
        self.deconv2 = ConvLayer(conv, n_feats, n_feats, kernel_size=3, stride=1, norm=norm, act=act)
        if pyramid:
            self.pyramid_module = PyramidPooling(n_feats, n_feats, scales=(4, 8, 16, 32), ct_channels=n_feats // 4)
        self.deconv3 = ConvLayer(conv, n_feats, out_channels, kernel_size=1, stride=1, norm=None, act=act)

    def forward(self, x):
        """Run the encoder, residual body, decoder and optional pyramid module on ``x``."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.res_module(x)
        x = self.deconv1(x)
        x = self.deconv2(x)
        if self.pyramid_module is not None:
            x = self.pyramid_module(x)
        x = self.deconv3(x)
        return x
class ConvLayer(torch.nn.Sequential):
    """Convolution followed by optional normalization and activation.

    Registers sub-modules named ``conv2d``, ``norm`` (only if ``norm`` is
    given) and ``act`` (only if ``act`` is given), in that order.

    Args:
        conv: Convolution class, called as
            ``conv(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation)``
            (e.g. ``nn.Conv2d`` or ``nn.ConvTranspose2d``).
        in_channels: Input channel count passed to ``conv``.
        out_channels: Output channel count passed to ``conv`` and ``norm``.
        kernel_size: Kernel size (int).
        stride: Convolution stride.
        padding: Explicit padding, honoured even when 0. ``None`` selects
            ``dilation * (kernel_size - 1) // 2``. For an ``nn.Conv2d`` with an
            odd kernel at stride 1, this keeps the spatial size unchanged.
        dilation: Convolution dilation.
        norm: Optional normalization class, instantiated as ``norm(out_channels)``.
        act: Optional activation module instance, added as-is.
    """

    def __init__(self, conv, in_channels, out_channels, kernel_size, stride, padding=None, dilation=1, norm=None, act=None):
        super(ConvLayer, self).__init__()
        # Test against None rather than truthiness: padding=0 is falsy but is
        # a legitimate explicit request for no padding.
        if padding is None:
            padding = dilation * (kernel_size - 1) // 2
        self.add_module('conv2d', conv(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation))
        if norm is not None:
            self.add_module('norm', norm(out_channels))
        if act is not None:
            self.add_module('act', act)
class ResidualBlock(torch.nn.Module):
    """Residual block: conv-norm-act -> conv-norm -> optional SE, scaled skip add.

    Computes ``x + res_scale * se(conv2(conv1(x)))``; the SE step is applied
    only when ``se_reduction`` is given. Both convolutions are 3x3, stride 1,
    keep the channel count, and use padding equal to the dilation, so the
    residual addition is shape-compatible.

    Args:
        channels: Input and output channel count.
        dilation: Dilation of both 3x3 convolutions.
        norm: Normalization class passed to each conv layer (or None).
        act: Activation module applied after the first convolution only.
        se_reduction: If not None, the reduction ratio of an SELayer applied
            to the residual branch.
        res_scale: Scalar multiplier on the residual branch.

    NOTE: the default ``act`` is a single ReLU instance created when the class
    is defined. It is shared by every block that relies on the default, which
    is harmless because ReLU holds no parameters.
    """

    def __init__(self, channels, dilation=1, norm=nn.BatchNorm2d, act=nn.ReLU(True), se_reduction=None, res_scale=1):
        super(ResidualBlock, self).__init__()
        conv = nn.Conv2d
        self.conv1 = ConvLayer(conv, channels, channels, kernel_size=3, stride=1, dilation=dilation, norm=norm, act=act)
        # No activation after the second conv: it is applied before the skip add.
        self.conv2 = ConvLayer(conv, channels, channels, kernel_size=3, stride=1, dilation=dilation, norm=norm, act=None)
        self.se_layer = None
        self.res_scale = res_scale
        if se_reduction is not None:
            self.se_layer = SELayer(channels, se_reduction)

    def forward(self, x):
        """Return ``x`` plus the scaled (and optionally SE-gated) residual branch."""
        residual = x
        out = self.conv2(self.conv1(x))
        # Compare with None explicitly: an nn.Module's truthiness is not a
        # reliable presence test (container modules define __len__).
        if self.se_layer is not None:
            out = self.se_layer(out)
        out = out * self.res_scale
        return out + residual

    def extra_repr(self):
        """Show ``res_scale`` in the module's printed representation."""
        return 'res_scale={}'.format(self.res_scale)
四、相关资料
Single Image Reflection Removal Exploiting Misaligned Training Data and Network Enhancements