开篇
在计算机视觉方向我们介绍了不少基础网络了,今天介绍的这种又是计算机视觉方向的一个骨灰级网络——GAN。GAN又名生成对抗网络,其主要作用是图像生成,我们在用图像训练模型的时候需要大量的数据集。但是如果我们的数据集不够怎么办呢?我们可以利用数据增强的方法,对图像进行上下左右的翻转,做随即剪切,也可以自己生成图像。这个生成图像就会用到我们的GAN网络。
GAN网络之所叫对抗网络是因为其内部有两个编码器,一个generator和一个discriminator。一个用于编码生成图像,一个用于将图像解码。generator企图生成的图像足够像原始图像,企图以假乱真;而discriminator企图戳穿generator的把戏,将其精准辨别真伪。整个网络就在二者的博弈中生成了图像。discriminator主要是判别generator产生的编码和真实图像的解码是否相似,不断提高二者的相似度,最终生成了可以以假乱真的图像。
其实通过这个描述大家就可以意识到,这应该是一个最小最大问题或者是一个最大最小问题。因为discriminator拼命想区分二者,所以他应该让二者区别足够大;而generator拼命想效仿,所以他应该让二者区别足够小。
这里简单介绍一下GAN,详细介绍可以参考GAN原理学习。我们主要看代码实现。
GAN生成对抗网络
库的引入
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image~
设备的配置以及超参数的定义
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'sample'~
如果你了解GAN的话,你应该可以清楚这个latent size,他其实是我们在generator生成图像网络中的隐藏层特征尺寸。
图片的生成地址
我们最终要把生成的图片放到一个文件夹中,所以我们创建一个目录用于存储生成的图片
if not os.path.exists(sample_dir):
os.makedirs(sample_dir)~
图像的处理和转换
transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5],
std=[0.5])])~
我们将像素点进行归一化,均值为0.5,方差也为0.5。
这里说明一下,由于我们所用的图像要经过灰度转化变为灰度图,所以这里的channel是1维,如果是彩色图,我们则有三个channels,需要对每一个channel都指定均值和方差
数据的引入和加载
mnist = torchvision.datasets.MNIST(root = '../../data/',
train = True,
transforms = transforms,
download = True)
、
data_loader = torch.utils.data.DataLoader(dataset = mnist,batch_size = batch_size,
shuffle = True)~
GAN网络是在对抗的过程中逐步完善逐步提高以假乱真的水平,因此我们不需要分为测试机和训练集,用一份统一的数据就可以了。
Generator和Discrimator的定义
# Discrimator
D = nn.Sequential(
nn.Linear(image_size,hidden_size),
nn.LeakyReLU(0.2),
nn.Linear(hidden_size,hidden_size),
nn.LeakyReLU(0.2),
nn.Linear(hidden_size,1),
nn.Sigmoid())
# Generator
G = nn.Sequential(
nn.Linear(latent_size,hidden_size),
nn.ReLU(),
nn.Linear(hidden_size,hidden_size),
nn.ReLU(),
nn.Linear(hidden_size,image_size),
nn.Tanh())
D = D.to(device)
G = G.to(device)~
D与G的结构很相似,区别在于他们的激活函数,在Discrimator中我们通常使用leakyrelu,最后一层使用sigmoid来生成概率。而generator中我们激活函数是relu,最后一层使用双曲正切函数,因为它只用于解码生成图像,不需要计算概率。
损失函数和优化器的定义
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(),lr = 0.0002)
g_optimizer = torch.optim.Adam(G.parameters(),lr = 0.0002)~
这里说明一下,我们使用的BCELoss是二分交叉熵损失函数,这个损失函数具体的形式大家可以看前文的超链接中提到的公式,这里不做展开了。
辅助函数的定义
def denorm(x):
out = (x + 1) / 2
# 将out限制在0-1
return out.clamp(0,1)
# 重置梯度
def reset_grad():
d_optimizer.zero_grad()
g_optimizer.zero_grad()
训练模型
total_step = len(data_loader)
for epoch in range(num_epochs):
for i,(images,_) in enumerate(data_loader):
images = images.reshape(batch_size,-1).to(device)
# 为计算损失函数生成标签,真是标签是1,虚假标签是0
real_labels = torch.ones(batch_size,1).to(device)
fake_labels = torch.zeros(batch_size,1).to(device)
# Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
# 这里的损失函数应该都是0,因为我们用真实图片去测试,损失函数一定是0
# 我们的目的是将对的分到real中,错的分到fake中,所以要求2个损失
outputs = D(images)
d_loss_real = criterion(outputs,real_labels)
real_score = outputs
# Compute BCELoss using fake images
# 这里的损失函数应该是1,因为我们用的是虚假图片,且为随机生成的码
z = torch.randn(batch_size,latent_size).to(device)
fake_images = G(z)
outputs = D(fake_images)
d_loss_fake = criterion(outputs,fake_labels)
fake_score = outputs
# 反向传播优化
d_loss = d_loss_fake + d_loss_real
# 清空梯度
reset_grad()
d_loss.backward()
d_optimizer.step()
# 训练生成器
# 用虚假图片计算损失
z = torch.randn(batch_size,latent_size).to(device)
fake_images = G(z)
outputs = D(fake_images)
# We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
g_loss = criterion(outputs,real_labels)
# 反向传播优化
reset_grad()
g_loss.backward()
g_optimizer.step()
if (i+1) % 200 == 0:
print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
.format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
real_score.mean().item(), fake_score.mean().item()))
# 保存真实图像
if (epoch + 1) == 1:
images = images.reshape(images.size(0),1,28,28)
save_image(denorm(images),os.path.join(sample_dir,'real_images.png'))
# 保存样本图像
fake_images = fake_images.reshape(fake_images.size(0),1,28,28)
save_image(denorm(fake_images),os.path.join(sample_dir,'fake_images-{}.png'.format(epoch+1)))
保存模型
torch.save(G.state_dict(),'G.cpkt')
torch.save(D.state_dict(),'D.cpkt')
总结
GAN是一种比较常用的生成图像或者是判断两个图像间差异的网络,应用较多而且还有很多变体,比如DCGAN或者是CGAN,大家如果感兴趣可以精读一下相关论文。好啦GAN就介绍到这里啦,下次我们说VAE变分自编码器。