# UNet with a ResNet-family backbone used as the encoder.
import functools
import torch.utils.model_zoo as model_zoo
from torchvision.models.resnet import ResNet
from torchvision.models.resnet import BasicBlock
from torchvision.models.resnet import Bottleneck
from pretrainedmodels.models.torchvision_models import pretrained_settings
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class Model(nn.Module):
    """Base class for segmentation models; provides standard weight init."""

    def __init__(self):
        super().__init__()

    def initialize(self):
        """Initialize all submodules in place.

        Conv2d weights get Kaiming-normal init (fan_out, ReLU gain) — the
        standard scheme for ReLU networks; BatchNorm2d starts as identity
        (weight=1, bias=0).
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
class EncoderDecoder(Model):
    """Generic encoder-decoder segmentation model.

    Args:
        encoder: feature-extractor module; its output feeds the decoder.
        decoder: module mapping encoder features to per-class logits.
        activation: one of 'sigmoid', 'softmax', a callable, or None.
            Applied only in `predict()`, never in `forward()`.

    Raises:
        ValueError: if `activation` is not one of the accepted values.
    """

    def __init__(self, encoder, decoder, activation):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        if callable(activation) or activation is None:
            self.activation = activation
        elif activation == 'softmax':
            # dim=1 — softmax over the class/channel axis of (N, C, H, W)
            self.activation = nn.Softmax(dim=1)
        elif activation == 'sigmoid':
            self.activation = nn.Sigmoid()
        else:
            raise ValueError('Activation should be "sigmoid"/"softmax"/callable/None')

    def forward(self, x):
        """Sequentially pass `x` through the `encoder` and `decoder` (returns logits!)."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def predict(self, x):
        """Inference method: switch to `eval`, forward, then apply activation.

        Runs under `torch.no_grad()`; applies `self.activation` only when set.

        Args:
            x: 4D torch tensor with shape (batch_size, channels, height, width)

        Return:
            prediction: 4D torch tensor with shape (batch_size, classes, height, width)
        """
        if self.training:
            self.eval()
        with torch.no_grad():
            x = self.forward(x)
            if self.activation:
                x = self.activation(x)
        return x
class Unet(EncoderDecoder):
    """Unet_ is a fully convolutional neural network for image semantic segmentation.

    Args:
        encoder_name: name of classification model (without last dense layers) used as
            feature extractor to build the segmentation model.
        encoder_weights: one of ``None`` (random initialization), ``imagenet``
            (pre-training on ImageNet).
        decoder_use_batchnorm: if ``True``, a ``BatchNormalisation`` layer is used
            between ``Conv2D`` and ``Activation`` layers.
        decoder_channels: list of numbers of ``Conv2D`` layer filters in decoder blocks.
        classes: number of classes for output (output shape - ``(batch, classes, h, w)``).
        activation: activation function used in the ``.predict(x)`` method for inference.
            One of [``sigmoid``, ``softmax``, callable, None].
        center: if ``True`` add a ``Conv2dReLU`` block on the encoder head
            (useful for VGG models).

    Returns:
        ``torch.nn.Module``: **Unet**
    """

    def __init__(
        self,
        encoder_name='resnet34',
        encoder_weights='imagenet',
        decoder_use_batchnorm=True,
        decoder_channels=(256, 128, 64, 32, 16),
        classes=1,
        activation='sigmoid',
        center=False,  # useful for VGG models
    ):
        # NOTE(review): get_encoder / UnetDecoder are defined elsewhere in
        # this project; encoder.out_shapes presumably lists the channel counts
        # of the encoder's feature maps — confirm against their definitions.
        encoder = get_encoder(
            encoder_name,
            encoder_weights=encoder_weights,
        )
        decoder = UnetDecoder(
            encoder_channels=encoder.out_shapes,
            decoder_channels=decoder_channels,
            final_channels=classes,
            use_batchnorm=decoder_use_batchnorm,
            center=center,
        )
        super().__init__(encoder, decoder, activation)
        self.name = 'u-{}'.format(encoder_name)
from torchvision.models.resnet import ResNet
from torchvision.models.resnet import BasicBlock
from torchvision.models.resnet import Bottleneck