Unet的网络结构:
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
import torch.utils.data as Data
import random

# Fix the seed of every random number generator the script may draw from so
# that runs are reproducible.
seed = 2019

for _seed_fn in (
    torch.manual_seed,            # PyTorch CPU generator
    torch.cuda.manual_seed,       # current CUDA device
    torch.cuda.manual_seed_all,   # every CUDA device (multi-GPU training)
    np.random.seed,               # NumPy global generator
    random.seed,                  # Python standard-library generator
):
    _seed_fn(seed)
del _seed_fn

# Turn off the cuDNN autotuner and request deterministic cuDNN algorithms so
# repeated runs produce the same results.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# Convolution factory used throughout the network.
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Create an unpadded ("valid") 2-D convolution layer.

    With ``padding=0`` and the default stride of 1, each spatial dimension
    shrinks by ``kernel_size - 1``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: size of the (square) kernel.
        bias: whether the layer has a learnable bias.

    Returns:
        A new ``nn.Conv2d`` module.
    """
    layer = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        padding=0,
        bias=bias,
    )
    return layer
# Activation factory used throughout the network.
def default_relu():
    """Return a new in-place ReLU activation module."""
    activation = nn.ReLU(inplace=True)
    return activation
class Up_Sample(nn.Module):
    """Decoder step of the U-Net: up-sample, then join the skip connection.

    The deeper feature map is up-sampled x2 (nearest neighbour), its channel
    count is halved by a 1x1 convolution followed by the activation, and the
    result is concatenated along the channel axis with the centre crop of the
    matching encoder feature map.

    Args:
        in_channels: channel count of ``input_down``; the up-sampled branch
            has ``in_channels // 2`` channels.
        conv: convolution factory called as ``conv(in_ch, out_ch, kernel)``.
        relu: activation factory called with no arguments.
    """

    def __init__(self, in_channels, conv=default_conv, relu=default_relu):
        super(Up_Sample, self).__init__()
        up1 = nn.Upsample(scale_factor=2, mode='nearest')
        up2 = conv(in_channels, in_channels // 2, 1)
        self.module_up = nn.Sequential(up1, up2, relu())

    def forward(self, input_down, input_left):
        """Up-sample ``input_down`` and concatenate the cropped ``input_left``.

        Args:
            input_down: deeper feature map; indexed as (N, C, H, W).
            input_left: encoder feature map at least as large spatially as the
                up-sampled ``input_down`` (assumed; not checked here).

        Returns:
            ``torch.cat((upsampled, cropped_skip), dim=1)``.
        """
        x = self.module_up(input_down)
        out_h, out_w = x.shape[2], x.shape[3]
        # Centre-crop height and width independently so that non-square
        # feature maps line up with the up-sampled tensor.
        top = (input_left.shape[2] - out_h) // 2
        left = (input_left.shape[3] - out_w) // 2
        input_left = input_left[:, :, top:top + out_h, left:left + out_w]
        return torch.cat((x, input_left), 1)
class Unet(nn.Module):
    """U-Net built from unpadded 3x3 convolutions.

    Encoder: four stages of two 3x3 convolutions (each followed by the
    activation), each stage followed by 2x2 max-pooling with stride 2, then a
    bottom stage. Decoder: four ``Up_Sample`` steps, each followed by two 3x3
    convolutions. A final 1x1 convolution maps to ``out_channels``. Because the
    convolutions are unpadded, the output is spatially smaller than the input
    (a 572x572 input gives a 388x388 output).

    Args:
        in_channels: channels of the input image.
        out_channels: channels produced by the final 1x1 convolution.
        conv: convolution factory called as ``conv(in_ch, out_ch, kernel)``.
        relu: activation factory called with no arguments.
        n_feats: channel width of the first stage; each deeper stage doubles it.
    """

    def __init__(self, in_channels, out_channels, conv=default_conv, relu=default_relu, n_feats=64):
        super(Unet, self).__init__()
        # Channel widths per depth: n, 2n, 4n, 8n, 16n.
        widths = [n_feats * 2 ** depth for depth in range(5)]

        self.left1 = self._double_conv(conv, relu, in_channels, widths[0])
        self.left2 = self._double_conv(conv, relu, widths[0], widths[1])
        self.left3 = self._double_conv(conv, relu, widths[1], widths[2])
        self.left4 = self._double_conv(conv, relu, widths[2], widths[3])
        self.bottom = self._double_conv(conv, relu, widths[3], widths[4])
        # Decoder stages take the concatenation of the up-sampled branch
        # (half the deeper width) and the skip connection, i.e. 2x their output.
        self.right1 = self._double_conv(conv, relu, widths[1], widths[0])
        self.right2 = self._double_conv(conv, relu, widths[2], widths[1])
        self.right3 = self._double_conv(conv, relu, widths[3], widths[2])
        self.right4 = self._double_conv(conv, relu, widths[4], widths[3])
        self.tail = conv(n_feats, out_channels, 1)

        # A 2x2 window is required for actual max-pooling; a 1x1 window with
        # stride 2 would merely drop every other pixel.
        self.down = nn.ModuleList(
            nn.MaxPool2d(kernel_size=2, stride=2) for _ in range(4)
        )
        # Up_Sample.forward takes two inputs, so these are held in a
        # ModuleList (same state_dict keys as the previous Sequential) and are
        # built with the same conv/relu factories as the rest of the network.
        self.up = nn.ModuleList(
            Up_Sample(in_channels=widths[depth + 1], conv=conv, relu=relu)
            for depth in range(4)
        )

    @staticmethod
    def _double_conv(conv, relu, in_ch, out_ch):
        """Two 3x3 convolutions, each followed by the activation.

        The convolutions sit at indices 0 and 2, so parameter names in the
        state_dict are ``<stage>.0.*`` and ``<stage>.2.*``.
        """
        return nn.Sequential(
            conv(in_ch, out_ch, 3),
            relu(),
            conv(out_ch, out_ch, 3),
            relu(),
        )

    def forward(self, x):
        """Run the U-Net.

        Args:
            x: input batch, assumed (N, in_channels, H, W); H and W must be
                large enough for every unpadded convolution and crop.

        Returns:
            Tensor with ``out_channels`` channels, spatially smaller than ``x``.
        """
        # Encoder: keep each stage's output for the skip connections.
        x1 = self.left1(x)
        x2 = self.left2(self.down[0](x1))
        x3 = self.left3(self.down[1](x2))
        x4 = self.left4(self.down[2](x3))
        x_b = self.bottom(self.down[3](x4))

        # Decoder: up-sample, concatenate the cropped skip, then convolve.
        y3 = self.right4(self.up[3](x_b, x4))
        y2 = self.right3(self.up[2](y3, x3))
        y1 = self.right2(self.up[1](y2, x2))
        y = self.right1(self.up[0](y1, x1))
        return self.tail(y)
def main():
    """Build a 1-channel-in / 2-channel-out U-Net and print a layer summary.

    Runs on the GPU when CUDA is available and falls back to the CPU
    otherwise, so the summary also works on machines without a GPU.
    """
    from torchsummary import summary  # only needed for this demo

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Unet(in_channels=1, out_channels=2).to(device)
    summary(model, (1, 572, 572), device=device)


if __name__ == '__main__':
    main()
打印模型:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 570, 570] 640
ReLU-2 [-1, 64, 570, 570] 0
Conv2d-3 [-1, 64, 568, 568] 36,928
MaxPool2d-4 [-1, 64, 284, 284] 0
Conv2d-5 [-1, 128, 282, 282] 73,856
ReLU-6 [-1, 128, 282, 282] 0
Conv2d-7 [-1, 128, 280, 280] 147,584
MaxPool2d-8 [-1, 128, 140, 140] 0
Conv2d-9 [-1, 256, 138, 138] 295,168
ReLU-10 [-1, 256, 138, 138] 0
Conv2d-11 [-1, 256, 136, 136] 590,080
MaxPool2d-12 [-1, 256, 68, 68] 0
Conv2d-13 [-1, 512, 66, 66] 1,180,160
ReLU-14 [-1, 512, 66, 66] 0
Conv2d-15 [-1, 512, 64, 64] 2,359,808
MaxPool2d-16 [-1, 512, 32, 32] 0
Conv2d-17 [-1, 1024, 30, 30] 4,719,616
ReLU-18 [-1, 1024, 30, 30] 0
Conv2d-19 [-1, 1024, 28, 28] 9,438,208
Upsample-20 [-1, 1024, 56, 56] 0
Conv2d-21 [-1, 512, 56, 56] 524,800
ReLU-22 [-1, 512, 56, 56] 0
Up_Sample-23 [-1, 1024, 56, 56] 0
Conv2d-24 [-1, 512, 54, 54] 4,719,104
ReLU-25 [-1, 512, 54, 54] 0
Conv2d-26 [-1, 512, 52, 52] 2,359,808
Upsample-27 [-1, 512, 104, 104] 0
Conv2d-28 [-1, 256, 104, 104] 131,328
ReLU-29 [-1, 256, 104, 104] 0
Up_Sample-30 [-1, 512, 104, 104] 0
Conv2d-31 [-1, 256, 102, 102] 1,179,904
ReLU-32 [-1, 256, 102, 102] 0
Conv2d-33 [-1, 256, 100, 100] 590,080
Upsample-34 [-1, 256, 200, 200] 0
Conv2d-35 [-1, 128, 200, 200] 32,896
ReLU-36 [-1, 128, 200, 200] 0
Up_Sample-37 [-1, 256, 200, 200] 0
Conv2d-38 [-1, 128, 198, 198] 295,040
ReLU-39 [-1, 128, 198, 198] 0
Conv2d-40 [-1, 128, 196, 196] 147,584
Upsample-41 [-1, 128, 392, 392] 0
Conv2d-42 [-1, 64, 392, 392] 8,256
ReLU-43 [-1, 64, 392, 392] 0
Up_Sample-44 [-1, 128, 392, 392] 0
Conv2d-45 [-1, 64, 390, 390] 73,792
ReLU-46 [-1, 64, 390, 390] 0
Conv2d-47 [-1, 64, 388, 388] 36,928
Conv2d-48 [-1, 2, 388, 388] 130
================================================================
Total params: 28,941,698
Trainable params: 28,941,698
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.25
Forward/backward pass size (MB): 2275.74
Params size (MB): 110.40
Estimated Total Size (MB): 2387.39
----------------------------------------------------------------