# TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation
By [Vladimir Iglovikov](https://www.linkedin.com/in/iglovikov/) and [Alexey Shvets](https://www.linkedin.com/in/shvetsiya/)
# Introduction
TernausNet is a modification of the celebrated UNet architecture that is widely used for binary Image Segmentation. For more details, please refer to our [arXiv paper](https://arxiv.org/abs/1801.05746).
(Network architecure)
Pre-trained encoder speeds up convergence even on the datasets with a different semantic features. Above curve shows validation Jaccard Index (IOU) as a function of epochs for [Aerial Imagery](https://project.inria.fr/aerialimagelabeling/)
This architecture was a part of the [winning solutiuon](http://blog.kaggle.com/2017/12/22/carvana-image-masking-first-place-interview/) (1st out of 735 teams) in the [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge).
https://github.com/ternaus/robot-surgery-segmentation
Jupyter Notebook
Example
Last Checkpoint: 27分钟前
(autosaved)
Current Kernel Logo
Python 3
File
Edit
View
Insert
Cell
Kernel
Widgets
Help
%matplotlib inline
import cv2
from pylab import *
rcParams['figure.figsize'] = 15, 15
import torch
from torch import nn
from unet_models import unet11
from pathlib import Path
from torch.nn import functional as F
from torchvision.transforms import ToTensor, Normalize, Compose
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def get_model():
model = unet11(pretrained='carvana')
model.eval()
return model.to(device)
def mask_overlay(image, mask, color=(0, 255, 0)):
"""
Helper function to visualize mask on the top of the car
"""
mask = np.dstack((mask, mask, mask)) * np.array(color)
mask = mask.astype(np.uint8)
weighted_sum = cv2.addWeighted(mask, 0.5, image, 0.5, 0.)
img = image.copy()
ind = mask[:, :, 1] > 0
img[ind] = weighted_sum[ind]
return img
def load_image(path, pad=True):
"""
Load image from a given path and pad it on the sides, so that eash side is divisible by 32 (newtwork requirement)
if pad = True:
returns image as numpy.array, tuple with padding in pixels as(x_min_pad, y_min_pad, x_max_pad, y_max_pad)
else:
returns image as numpy.array
"""
img = cv2.imread(str(path))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if not pad:
return img
height, width, _ = img.shape
if height % 32 == 0:
y_min_pad = 0
y_max_pad = 0
else:
y_pad = 32 - height % 32
y_min_pad = int(y_pad / 2)
y_max_pad = y_pad - y_min_pad
if width % 32 == 0:
x_min_pad = 0
x_max_pad = 0
else:
x_pad = 32 - width % 32
x_min_pad = int(x_pad / 2)
x_max_pad = x_pad - x_min_pad
img = cv2.copyMakeBorder(img, y_min_pad, y_max_pad, x_min_pad, x_max_pad, cv2.BORDER_REFLECT_101)
return img, (x_min_pad, y_min_pad, x_max_pad, y_max_pad)
img_transform = Compose([
ToTensor(),
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def crop_image(img, pads):
"""
img: numpy array of the shape (height, width)
pads: (x_min_pad, y_min_pad, x_max_pad, y_max_pad)
@return padded image
"""
(x_min_pad, y_min_pad, x_max_pad, y_max_pad) = pads
height, width = img.shape[:2]
return img[y_min_pad:height - y_max_pad, x_min_pad:width - x_max_pad]
model = get_model()
Example on Carvana dataset car
img, pads = load_image('lexus.jpg', pad=True)
imshow(img)
<matplotlib.image.AxesImage at 0x7f61543bd080>
with torch.no_grad():
input_img = torch.unsqueeze(img_transform(img).to(device), dim=0)
with torch.no_grad():
mask = F.sigmoid(model(input_img))
mask_array = mask.data[0].cpu().numpy()[0]
mask_array = crop_image(mask_array, pads)
imshow(mask_array)
<matplotlib.image.AxesImage at 0x7f61542b0e48>
imshow(mask_overlay(crop_image(img, pads), (mask_array > 0.5).astype(np.uint8)))
<matplotlib.image.AxesImage at 0x7f6154254a58>
from torch import nn
from torch.nn import functional as F
import torch
from torchvision import models
import torchvision
def conv3x3(in_, out):
return nn.Conv2d(in_, out, 3, padding=1)
class ConvRelu(nn.Module):
def __init__(self, in_, out):
super().__init__()
self.conv = conv3x3(in_, out)
self.activation = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.activation(x)
return x
class DecoderBlock(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels):
super().__init__()
self.block = nn.Sequential(
ConvRelu(in_channels, middle_channels),
nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.block(x)
class UNet11(nn.Module):
def __init__(self, num_filters=32, pretrained=False):
"""
:param num_classes:
:param num_filters:
:param pretrained:
False - no pre-trained network is used
True - encoder is pre-trained with VGG11
"""
super().__init__()
self.pool = nn.MaxPool2d(2, 2)
self.encoder = models.vgg11(pretrained=pretrained).features
self.relu = self.encoder[1]
self.conv1 = self.encoder[0]
self.conv2 = self.encoder[3]
self.conv3s = self.encoder[6]
self.conv3 = self.encoder[8]
self.conv4s = self.encoder[11]
self.conv4 = self.encoder[13]
self.conv5s = self.encoder[16]
self.conv5 = self.encoder[18]
self.center = DecoderBlock(num_filters * 8 * 2, num_filters * 8 * 2, num_filters * 8)
self.dec5 = DecoderBlock(num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 8)
self.dec4 = DecoderBlock(num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 4)
self.dec3 = DecoderBlock(num_filters * (8 + 4), num_filters * 4 * 2, num_filters * 2)
self.dec2 = DecoderBlock(num_filters * (4 + 2), num_filters * 2 * 2, num_filters)
self.dec1 = ConvRelu(num_filters * (2 + 1), num_filters)
self.final = nn.Conv2d(num_filters, 1, kernel_size=1)
def forward(self, x):
conv1 = self.relu(self.conv1(x))
conv2 = self.relu(self.conv2(self.pool(conv1)))
conv3s = self.relu(self.conv3s(self.pool(conv2)))
conv3 = self.relu(self.conv3(conv3s))
conv4s = self.relu(self.conv4s(self.pool(conv3)))
conv4 = self.relu(self.conv4(conv4s))
conv5s = self.relu(self.conv5s(self.pool(conv4)))
conv5 = self.relu(self.conv5(conv5s))
center = self.center(self.pool(conv5))
dec5 = self.dec5(torch.cat([center, conv5], 1))
dec4 = self.dec4(torch.cat([dec5, conv4], 1))
dec3 = self.dec3(torch.cat([dec4, conv3], 1))
dec2 = self.dec2(torch.cat([dec3, conv2], 1))
dec1 = self.dec1(torch.cat([dec2, conv1], 1))
return self.final(dec1)
def unet11(pretrained=False, **kwargs):
"""
pretrained:
False - no pre-trained network is used
True - encoder is pre-trained with VGG11
carvana - all weights are pre-trained on
Kaggle: Carvana dataset https://www.kaggle.com/c/carvana-image-masking-challenge
"""
model = UNet11(pretrained=pretrained, **kwargs)
if pretrained == 'carvana':
state = torch.load('TernausNet.pt')
model.load_state_dict(state['model'])
return model
class Interpolate(nn.Module):
def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=False):
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.size = size
self.mode = mode
self.scale_factor = scale_factor
self.align_corners = align_corners
def forward(self, x):
x = self.interp(x, size=self.size, scale_factor=self.scale_factor,
mode=self.mode, align_corners=self.align_corners)
return x
class DecoderBlockV2(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
super(DecoderBlockV2, self).__init__()
self.in_channels = in_channels
if is_deconv:
"""
Paramaters for Deconvolution were chosen to avoid artifacts, following
link https://distill.pub/2016/deconv-checkerboard/
"""
self.block = nn.Sequential(
ConvRelu(in_channels, middle_channels),
nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
padding=1),
nn.ReLU(inplace=True)
)
else:
self.block = nn.Sequential(
Interpolate(scale_factor=2, mode='bilinear'),
ConvRelu(in_channels, middle_channels),
ConvRelu(middle_channels, out_channels),
)
def forward(self, x):
return self.block(x)
class AlbuNet(nn.Module):
"""
UNet (https://arxiv.org/abs/1505.04597) with Resnet34(https://arxiv.org/abs/1512.03385) encoder
Proposed by Alexander Buslaev: https://www.linkedin.com/in/al-buslaev/
"""
def __init__(self, num_classes=1, num_filters=32, pretrained=False, is_deconv=False):
"""
:param num_classes:
:param num_filters:
:param pretrained:
False - no pre-trained network is used
True - encoder is pre-trained with resnet34
:is_deconv:
False: bilinear interpolation is used in decoder
True: deconvolution is used in decoder
"""
super().__init__()
self.num_classes = num_classes
self.pool = nn.MaxPool2d(2, 2)
self.encoder = torchvision.models.resnet34(pretrained=pretrained)
self.relu = nn.ReLU(inplace=True)
self.conv1 = nn.Sequential(self.encoder.conv1,
self.encoder.bn1,
self.encoder.relu,
self.pool)
self.conv2 = self.encoder.layer1
self.conv3 = self.encoder.layer2
self.conv4 = self.encoder.layer3
self.conv5 = self.encoder.layer4
self.center = DecoderBlockV2(512, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec5 = DecoderBlockV2(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec4 = DecoderBlockV2(256 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec3 = DecoderBlockV2(128 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
self.dec2 = DecoderBlockV2(64 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2, is_deconv)
self.dec1 = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
self.dec0 = ConvRelu(num_filters, num_filters)
self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
def forward(self, x):
conv1 = self.conv1(x)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
conv4 = self.conv4(conv3)
conv5 = self.conv5(conv4)
center = self.center(self.pool(conv5))
dec5 = self.dec5(torch.cat([center, conv5], 1))
dec4 = self.dec4(torch.cat([dec5, conv4], 1))
dec3 = self.dec3(torch.cat([dec4, conv3], 1))
dec2 = self.dec2(torch.cat([dec3, conv2], 1))
dec1 = self.dec1(dec2)
dec0 = self.dec0(dec1)
if self.num_classes > 1:
x_out = F.log_softmax(self.final(dec0), dim=1)
else:
x_out = self.final(dec0)
return x_out
class UNet16(nn.Module):
def __init__(self, num_classes=1, num_filters=32, pretrained=False, is_deconv=False):
"""
:param num_classes:
:param num_filters:
:param pretrained:
False - no pre-trained network used
True - encoder pre-trained with VGG16
:is_deconv:
False: bilinear interpolation is used in decoder
True: deconvolution is used in decoder
"""
super().__init__()
self.num_classes = num_classes
self.pool = nn.MaxPool2d(2, 2)
self.encoder = torchvision.models.vgg16(pretrained=pretrained).features
self.relu = nn.ReLU(inplace=True)
self.conv1 = nn.Sequential(self.encoder[0],
self.relu,
self.encoder[2],
self.relu)
self.conv2 = nn.Sequential(self.encoder[5],
self.relu,
self.encoder[7],
self.relu)
self.conv3 = nn.Sequential(self.encoder[10],
self.relu,
self.encoder[12],
self.relu,
self.encoder[14],
self.relu)
self.conv4 = nn.Sequential(self.encoder[17],
self.relu,
self.encoder[19],
self.relu,
self.encoder[21],
self.relu)
self.conv5 = nn.Sequential(self.encoder[24],
self.relu,
self.encoder[26],
self.relu,
self.encoder[28],
self.relu)
self.center = DecoderBlockV2(512, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec5 = DecoderBlockV2(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec4 = DecoderBlockV2(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec3 = DecoderBlockV2(256 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
self.dec2 = DecoderBlockV2(128 + num_filters * 2, num_filters * 2 * 2, num_filters, is_deconv)
self.dec1 = ConvRelu(64 + num_filters, num_filters)
self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
def forward(self, x):
conv1 = self.conv1(x)
conv2 = self.conv2(self.pool(conv1))
conv3 = self.conv3(self.pool(conv2))
conv4 = self.conv4(self.pool(conv3))
conv5 = self.conv5(self.pool(conv4))
center = self.center(self.pool(conv5))
dec5 = self.dec5(torch.cat([center, conv5], 1))
dec4 = self.dec4(torch.cat([dec5, conv4], 1))
dec3 = self.dec3(torch.cat([dec4, conv3], 1))
dec2 = self.dec2(torch.cat([dec3, conv2], 1))
dec1 = self.dec1(torch.cat([dec2, conv1], 1))
if self.num_classes > 1:
x_out = F.log_softmax(self.final(dec1), dim=1)
else:
x_out = self.final(dec1)
return x_out