《Generative Adversarial Nets》
- 生成式对抗网络;
- 作者:lan Goodfellow;
- 单位:加拿大蒙特利尔大学;
- 发表会议及时间:NeurlPS(NIPS) 2014;
核心要点
- 提出了一个基于对抗的 新生成式模型,由一个生成器和一个判别器组成;
- 生成器的目标是学习到样本的数据分布,从而能生成样本欺骗判别器;判别器的目标是判断输入样本时生成/真实的概率;
- GAN模型等同于博弈论中的二人零和博弈;
- 对于任意的生成器和判别器,都存在一个独特的全局最优解;
- 在本文中,生成器和判别器都是由多层感知机实现,整个网络可以用反向传播算法来训练;
- 通过实验的定性与定量分析显示,GAN具备很大的潜力;
研究背景
1、零和博弈
- 一方的收益必然意味着另一方的损失,博弈各方的收益和损失相加总和永远为“零”,双方不存在合作的可能;
- 在零和博弈中,为了使己方达到最优解,所以把目标设为让对方的最大化收益最小化;
2、使用数据集
-
MNIST:手写数据集,源自NIST;28*28的灰度图,训练集60000张,测试集10000张;
-
TFD:The Toronro face dataset,人脸数据集;
-
CIFAR-10:32*32彩图,10个类别,每类6000张图,训练集50000张,测试集10000张;
3、GAN价值函数
价值函数
m i n G m a x D V ( D , G ) = E x ∼ p d a t a ( x ) [ l o g D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] min_G max_D V(D,G)=E_{x\sim p_{data}(x)}[log D(x)]+E_{z\sim p_z(z)}[log(1-D(G(z)))] minGmaxDV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
- d a t a data data:真实数据;
- D D D:判别器,输出值为[0,1],代表输入来自真实数据的概率;
- z z z:随机噪声;
- G G G:生成器,输出为合成数据;
判别器 D D D的目的是最大化价值函数 V V V,对数函数log在底数大于1时为单调递增函数,最大化 V V V就是最大化 D ( x ) D(x) D(x)和 1 − D ( G ( z ) ) 1-D(G(z)) 1−D(G(z)),对于任意的x,都有 D ( x ) = 1 D(x)=1 D(x)=1,对于任意的 z z z都有 D ( G ( z ) ) = 0 D(G(z))=0 D(G(z))=0。
生成器 G G G的目的是针对特定的 D D D,去最小化价值函数 V V V;最小化价值函数 V V V,就是最小化 D ( x ) D(x) D(x)和 1 − D ( G ( z ) ) 1-D(G(z)) 1−D(G(z));对于任意的 z z z,都有 D ( G ( z ) ) = 1 D(G(z))=1 D(G(z))=1。
训练小trick
- 在开始训练的时候,生成器 G G G的性能较差, D ( G ( z ) ) D(G(z)) D(G(z))接近于0,此时价值函数中的 l o g ( 1 − D ( G ( z ) ) ) log(1-D(G(z))) log(1−D(G(z)))的梯度值较小,而 l o g ( D ( G ( z ) ) ) log(D(G(z))) log(D(G(z)))的梯度值较大,所以可以把生成器 G G G的目标改为最大化 l o g D ( G ( z ) ) logD(G(z)) logD(G(z)),这样可以在早期学习中提供更强的梯度。
4、训练流程
- 使用mini-batch梯度下降(带momentum);
- 训练k次判别器(本论文实验中k=1);
- 训练1次生成器;
根据伪代码可以知道,对应两个神经网络模型——生成器 G G G和判别器 D D D,首先会固定生成器 G G G的参数,使用生成器 G G G生成的数据和真实的数据训练判别器 D D D,训练k次判别器 D D D后,固定判别器 D D D的参数,训练生成器 G G G。
理想情况下,判别器的最优解为: D G ∗ ( x ) = P d a t a ( x ) P d a t a ( x ) + P g ( x ) D^*_{G}(x)=\frac{P_{data}(x)}{P_{data}(x)+P_g(x)} DG∗(x)=Pdata(x)+Pg(x)Pdata(x)判别器取得最优解时,生成器的最优解为: P g = P d a t a P_g=P_{data} Pg=Pdata此时价值函数的值为 C ∗ = − l o g ( 4 ) C^*=-log(4) C∗=−log(4)
模型优劣势
缺点:
- 没有显式表示的 P g ( x ) P_g(x) Pg(x);
- 必须同步训练G和D,可能会发生模式崩溃;
优点:
- 不使用马尔科夫链,在学习过程中不需要推理;
- 可以将多种函数合并到模型中;
- 可以表示非常尖锐、甚至退化的分布;
- 不是直接使用数据来计算loss更新生成器,而是使用判别器的梯度,所以数据不会直接复制到生成器的参数中;
Pytorch代码
# 代码来源:https://github.com/eriklindernoren/PyTorch-GAN
import argparse
import os
import numpy as np
import math
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch
os.makedirs("images", exist_ok=True)
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training") # 迭代次数
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches") # 批量大小
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") # adam的学习率
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") # 动量法
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space") # 生成器输入维度
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension") # 照片的尺寸
parser.add_argument("--channels", type=int, default=1, help="number of image channels") # 通道数,1表示灰度图
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples") # 采样照片频率
opt = parser.parse_args()
print(opt)
img_shape = (opt.channels, opt.img_size, opt.img_size)
cuda = True if torch.cuda.is_available() else False
class Generator(nn.Module): # 生成器
def __init__(self):
super(Generator, self).__init__()
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.model = nn.Sequential(
*block(opt.latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), *img_shape)
return img
class Discriminator(nn.Module): # 判别器
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(int(np.prod(img_shape)), 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid(),
)
def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity
# Loss function
adversarial_loss = torch.nn.BCELoss()
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()
if cuda:
generator.cuda()
discriminator.cuda()
adversarial_loss.cuda()
# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
datasets.MNIST(
"../../data/mnist",
train=True, # 训练模式
download=True, # 如果MNIST没有下载则直接下载
transform=transforms.Compose(
[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
), # 照片处理方式
), # 数据集
batch_size=opt.batch_size, # 训练数据批量大小
shuffle=True, # 是否打乱
)
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) # 生成器的优化器
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) # 判别器的优化器
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
# ----------
# Training
# ----------
for epoch in range(opt.n_epochs):
for i, (imgs, _) in enumerate(dataloader):
# Adversarial ground truths
valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False) # 真实数据的label
fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) # 生成数据的label
# Configure input
real_imgs = Variable(imgs.type(Tensor)) # 真实照片
# -----------------
# Train Generator
# -----------------
optimizer_G.zero_grad() #
# Sample noise as generator input
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) # 生成随机分布的数据
# Generate a batch of images
gen_imgs = generator(z) # 生成器生成伪照片
# Loss measures generator's ability to fool the discriminator
g_loss = adversarial_loss(discriminator(gen_imgs), valid) # 生成器的目的是骗过判别器,所以希望生成器生成的照片被预测为1
g_loss.backward()
optimizer_G.step()
# ---------------------
# Train Discriminator
# ---------------------
optimizer_D.zero_grad()
# Measure discriminator's ability to classify real from generated samples
real_loss = adversarial_loss(discriminator(real_imgs), valid) # 判别器希望真实的照片预测为1
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake) # 判别器希望伪造的照片预测为0
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
)
batches_done = epoch * len(dataloader) + i
if batches_done % opt.sample_interval == 0:
save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
os.makedirs("model", exist_ok=True)
torch.save(generator, 'model/generator.pkl')
torch.save(discriminator, 'model/discriminator.pkl')