ZFNet
ZFNet论文:Visualizing and Understanding Convolutional Networks
ZFNet对AlexNet的网络结构进行了细微调整。
但是这篇论文另一个重要的贡献是提出了对卷积神经网络中间层可视化的方法。
代码展示
import time
import torch
from torch import nn, optim
import torchvision
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class ZFNet(nn.Module):
def __init__(self):
super(ZFNet, self).__init__()
self.conv = nn.Sequential(
# 第一层
nn.Conv2d(1, 96, 7, 2),
nn.ReLU(),
nn.MaxPool2d(3, 2),
# 第二次
nn.Conv2d(96, 256, 5, 2),
nn.ReLU(),
nn.MaxPool2d(3, 2),
# 第三层
nn.Conv2d(256, 384, 3, 1, 1),
nn.ReLU(),
# 第四层
nn.Conv2d(384, 384, 3, 1, 1),
nn.ReLU(),
# 第五层
nn.Conv2d(384, 256, 3, 1, 1),
nn.ReLU(),
nn.MaxPool2d(3, 2),
)
# 全连接层
self.fc = nn.Sequential(
nn.Linear(256 * 5 * 5, 4096),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(4096, 4096),
nn.ReLU(),
nn.Dropout(0.5),
# 输出层。由于这里使用Fashion-MNIST,所以用类别数为10
nn.Linear(4096, 10),
)
def forward(self, img):
feature = self.conv(img)
# print(feature.shape)
output = self.fc(feature.view(img.shape[0], -1))
return output