# The code in this section mainly comes from teacher Zhang; since it is quite
# old, I made some modifications so that it works with torch 1.4.
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torchvision.utils as vutil
import matplotlib.pyplot as plt
import numpy as np
# ---- Hyper-parameters ----
image_size = 28      # MNIST images are 28 x 28 pixels
input_dim = 100      # length of the generator's input code vector
num_channels = 1     # grayscale output images
num_features = 64    # base channel width of the generator
batch_size = 64
# Use the GPU when available; dtype/itype are the matching tensor constructors.
use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
itype = torch.cuda.LongTensor if use_cuda else torch.LongTensor
# ---- Data: MNIST train/test sets (downloaded into ./data on first run) ----
train_dataset = dsets.MNIST(root='./data',
                            train=True,
                            transform=transforms.ToTensor(),
                            download=True)
test_dataset = dsets.MNIST(root='./data',
                           train=False,
                           transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# Split the 10k-image test set: first 5000 indices for validation, rest for test.
indices = range(len(test_dataset))
indices_val = indices[:5000]
indices_test = indices[5000:]
sampler_val = torch.utils.data.sampler.SubsetRandomSampler(indices_val)
sampler_test = torch.utils.data.sampler.SubsetRandomSampler(indices_test)
validation_loader = torch.utils.data.DataLoader(dataset =test_dataset,
                                                batch_size = batch_size,
                                                sampler = sampler_val
                                                )
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          sampler = sampler_test
                                          )
class ModelG(nn.Module):
    """Transposed-convolution generator.

    Maps a (N, input_dim, 1, 1) code tensor to a (N, num_channels, 28, 28)
    image in [0, 1] via three ConvTranspose2d stages: 1x1 -> 5x5 -> 13x13 -> 28x28.
    """

    def __init__(self):
        super(ModelG, self).__init__()
        # Register stages under explicit names so state_dict keys stay stable.
        stages = [
            ('deconv1', nn.ConvTranspose2d(input_dim, num_features * 2, 5, 2, 0, bias=False)),
            ('bnorm1', nn.BatchNorm2d(num_features * 2)),
            ('relu1', nn.ReLU(True)),
            ('deconv2', nn.ConvTranspose2d(num_features * 2, num_features, 5, 2, 0, bias=False)),
            ('bnorm2', nn.BatchNorm2d(num_features)),
            ('relu2', nn.ReLU(True)),
            ('deconv3', nn.ConvTranspose2d(num_features, num_channels, 4, 2, 0, bias=False)),
            ('sigmoid', nn.Sigmoid()),
        ]
        self.model = nn.Sequential()
        for tag, layer in stages:
            self.model.add_module(tag, layer)

    def forward(self, x):
        # nn.Sequential already applies each child in registration order.
        return self.model(x)
def weight_init(m):
    """Custom per-module weight initializer, applied via net.apply(weight_init).

    (De)convolution layers get small zero-mean Gaussian weights; normalization
    layers get weights drawn around their neutral scale of 1.

    Bug fix: the original tested the lowercase substrings 'conv'/'norm' against
    class names such as 'ConvTranspose2d' and 'BatchNorm2d', so neither branch
    ever matched and the initializer was a silent no-op. The match is now
    case-insensitive.
    """
    class_name = m.__class__.__name__.lower()
    if class_name.find('conv') != -1:
        m.weight.data.normal_(0, 0.002)
    if class_name.find('norm') != -1:
        m.weight.data.normal_(1.0, 0.02)
def make_show(img):
    """Broadcast a batch of single-channel fakes to 3 channels for save_image.

    Relies on the module-level batch_size and image_size constants.
    """
    expanded = img.data.expand(batch_size, 3, image_size, image_size)
    return expanded
def img_show(inp, title=None, ax=None):
    """Render a (C, H, W) CPU tensor on a matplotlib axis, min-max scaled to [0, 1].

    inp: image tensor; multi-channel inputs are transposed to (H, W, C) for
         imshow, single-channel inputs are shown as a 2-D array.
    title: optional axis title.
    ax: target axis. Bug fix: the original crashed with AttributeError on the
        default ax=None; we now fall back to the current pyplot axis.
    """
    if ax is None:
        ax = plt.gca()
    if inp.size()[0] > 1:
        inp = inp.numpy().transpose((1, 2, 0))
    else:
        inp = inp[0].numpy()
    mvalue = np.amin(inp)
    maxvalue = np.amax(inp)
    if maxvalue > mvalue:  # guard against division by zero on constant images
        inp = (inp - mvalue) / (maxvalue - mvalue)
    ax.imshow(inp)
    if title is not None:
        ax.set_title(title)
# Channel counts of the classifier's two conv layers (conv1 -> 4, conv2 -> 8).
depth = [4,8]
class ConvNet(nn.Module):
    # Skeleton of the pretrained MNIST classifier. __init__ defines no layers;
    # presumably the instance used below comes from torch.load(...), whose
    # unpickling restores the saved attributes (conv1, pool, conv2, fc1, fc2)
    # directly — TODO confirm. A freshly constructed ConvNet() would lack
    # those attributes and its forward() would fail.
    def __init__(self):
        super(ConvNet, self).__init__()
    def forward(self,x):
        # conv -> relu -> pool, twice; each pool halves the spatial size
        # (28x28 -> 14x14 -> 7x7), then two fully connected layers.
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        # Flatten to (N, 7*7*depth[1]); the expression evaluates left-to-right
        # to 392 for image_size=28, depth[1]=8.
        x = x.view(-1, image_size // 4 * image_size // 4 * depth[1])
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Log-probabilities over the 10 digit classes.
        x = F.log_softmax(x, dim=1)
        return x
    def retrieve_features(self, x):
        # Return the two post-ReLU feature maps (for visualization/inspection).
        feature_map1 = F.relu(self.conv1(x))
        x = self.pool(feature_map1)
        feature_map2 = F.relu(self.conv2(x))
        return (feature_map1, feature_map2)
def rightness(predictions, target):
    """Count correct top-1 predictions.

    predictions: (N, C) score/logit tensor.
    target: (N,) tensor of class indices.
    Returns a tuple (number_correct, number_of_samples).
    """
    winners = predictions.argmax(dim=1)
    correct = winners.eq(target.data.view_as(winners)).sum()
    return correct, len(target)
# Load the pretrained MNIST classifier (a pickled full model; see ConvNet above).
netR = torch.load('mnist_conv_checkpoint')
netR = netR.cuda() if use_cuda else netR
# Freeze the classifier: only the generator netG receives gradient updates.
for para in netR.parameters():
    para.requires_grad = False
print('NOW LETS TRAIN')
netG = ModelG()
netG = netG.cuda() if use_cuda else netG
netG.apply(weight_init)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(netG.parameters(), lr=0.0001, momentum=0.9)
# A fixed batch of random class ids (0-9), cast to float and reused every
# epoch to visualize the generator's progress on identical inputs.
samples1 = np.random.choice(10, batch_size)
samples1 = torch.from_numpy(samples1).type(dtype).requires_grad_(False)
num_epochs = 20
statistics = []  # per-checkpoint loss/accuracy records for the final plot
step = 0         # global batch counter (drives the every-100-steps validation)
for epoch in range(num_epochs):
    train_loss = []
    train_rights = []
    for batch_idx, (data, target) in enumerate(train_loader):
        # Swap roles: the digit labels become the generator INPUT ("data"),
        # while a copy of them ("label") serves as the classification target.
        # The actual MNIST images are discarded.
        target, data = data.clone().detach().requires_grad_(True), target.clone().detach()
        if use_cuda:
            target, data = target.cuda(), data.cuda()
        label = data.clone()
        data = data.type(dtype)
        # Turn each scalar class id into a (N, input_dim, 1, 1) code by
        # repeating it along the channel dimension.
        data = data.reshape(data.size()[0], 1, 1, 1)
        data = data.expand(data.size()[0], input_dim, 1, 1)
        netR.train()
        netG.train()
        output1 = netG(data)      # generated 28x28 fake image batch
        output = netR(output1)    # frozen classifier's scores on the fakes
        loss = criterion(output,label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()          # only netG's parameters are updated
        step += 1
        if use_cuda:
            loss = loss.cpu()
        train_loss.append(loss.data.numpy())
        right = rightness(output, label)  # (correct, total) for this batch
        train_rights.append(right)
        if step % 100 == 0:
            netG.eval()
            netR.eval()
            val_loss = []
            val_rights = []
            # Loop over the validation set and measure accuracy there.
            '''开始在校验数据集上做循环,计算校验集上面的准确度'''
            for (data, target) in validation_loader:
                # Same label-as-input preprocessing as in the training loop.
                target, data = data.clone().detach().requires_grad_(True), target.clone().detach()
                if use_cuda:
                    target, data = target.cuda(), data.cuda()
                label = data.clone()
                data = data.type(dtype)
                data = data.reshape(data.size()[0], 1, 1, 1)
                data = data.expand(data.size()[0], input_dim, 1, 1)
                output1 = netG(data)
                output = netR(output1)
                loss = criterion(output, label)
                if use_cuda:
                    loss = loss.cpu()
                val_loss.append(loss.data.numpy())
                right = rightness(output, label)
                val_rights.append(right)
            # Aggregate the (correct, total) pairs accumulated so far.
            train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))
            val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))
            print(('训练周期: {} [{}/{} ({:.0f}%)]\t训练数据Loss: {:.6f},正确率: {:.2f}%\t校验数据Loss:' +
                   '{:.6f},正确率:{:.2f}%').format(epoch, batch_idx * batch_size, len(train_loader.dataset),
                                               100. * batch_idx / len(train_loader), np.mean(train_loss),
                                               100. * train_r[0] / train_r[1],
                                               np.mean(val_loss),
                                               100. * val_r[0] / val_r[1]))
            statistics.append({'loss': np.mean(train_loss), 'train': 100. * train_r[0] / train_r[1],
                               'valid': 100. * val_r[0] / val_r[1]})
    # Once per epoch, push the fixed sample codes through the generator and
    # save the resulting images as a grid.
    samples = samples1.data.reshape(batch_size, 1, 1, 1)
    samples = samples.data.expand(batch_size, input_dim, 1, 1)
    samples = samples.cuda() if use_cuda else samples
    fake_u = netG(samples)
    fake_u = fake_u.cpu() if use_cuda else fake_u
    img = make_show(fake_u)
    # NOTE(review): the 'temp1/' directory must already exist, otherwise
    # save_image raises — confirm or create it beforehand.
    vutil.save_image(img, 'temp1/fake%s.png' % (epoch))
# Convert the accumulated accuracy records into error rates and plot the
# training vs. validation curves over the recorded checkpoints.
train_err = []
valid_err = []
for record in statistics:
    train_err.append(100 - record['train'])
    valid_err.append(100 - record['valid'])
plt.figure(figsize = (10, 7))
plt.plot(train_err, label = 'Training')
plt.plot(valid_err, label = 'Validation')
plt.xlabel('Step')
plt.ylabel('Error Rate')
plt.legend()
plt.show()