前言
这一系列的博客,主要记录一下关于GAN的学习。已经有很多有关GAN的博客了。并且写的都特别好。而我还有写的目的,只是想记录一下自己的学习过程,自己学到的知识。
一、GAN的思想
生成对抗网络是 Ian Goodfellow在2014年提出的一种新模型。最初,是为了通过神经网络来生成数据。因此,Ian被称为生成对抗网络之父。生成对抗网络由两部分组成,分别是生成器和判别器。在训练的过程中,需要两者很好的配合。生成器就像一个制作假钞的坏人,它的成长过程是从一个零基础的小白慢慢成长为一个制作假钞的高手。而判别器就是一个识别假钞的警察,最开始也是一个识别假钞的普通警察。通过与制作假钞的坏人不断博弈,相互促进。最终达到识别假钞的警察无法识别出,假钞的真假。
上面通过一个简单的例子描述了一下GAN的思想。下面我们通过李宏毅老师的课程中的一个例子,来详细的看一下GAN的思想。
首先,生成器和判别器都是一个神经网络。在最初的论文中,使用的是MLP。生成器的输入是噪声,噪声输入到生成器中,输出是一个图片。此时的生成器能力还很差,生成的图片质量很差,然后将真实的图片与生成的图片一起输入到判别器中,让判别器去判别那个是真那个是假。当判别器v1能够很好地判别出图像的真假时,固定判别器v1的参数。继续训练生成器,不断的调整生成器的参数,直到判别器不能判别出生成的图像是假图像。此时的生成器v2生成图片要比之前生成器v1好一些。然后继续训练判别器,一样达到判别器很好的判别生成器v2的图片是假图片。以此类推,不断的进行下去。最后。达到的转态是判别器无法识别出图片的真假。此时生成器生成的图片和真实图片基本类似。算法过程如下:
二、GAN的公式推导
我们先了解一下三个主要公式:
连续函数的数学期望:
KL散度:
JSD散度:
接下来,我们就看看有关GAN的数学原理。我们将真实训练数据集定义为一个概率分布函数。同样,为了逼近真实数据的概率分布,定义一个生成模型的概率分布函数。其计算方法为:
在实际的运算中,我们是无法知道的形式的,我们可以做到的是从真实数据中采样大量的数据,也就是从中取出{},根据这些真实的数据,给定,计算。那么生成这m个样本数据的似然是
因此,我们的目标就是通过上面这个概率的式子,寻找出一个使得L最大化。这样做的实际含义是指,在给出真实训练集的前提下,我们希望生成模型能够在这些数据上具备最大的概率,这样才说明我们的生成模型在给出的训练集上能够逼近真实数据的概率分布。对于上边的连乘,我们取对数,变成连加,这样会更好计算。
然后把求和近似转化为求期望。并且写成积分形式。
然后,在不影响上式求解的情况下,减去一个与没有关系的常数项。
最后,我们希望最小化真实数据分布与生成数据分布之间的KL散度。从而使得生成模型分布接近真实数据分布。但是,这种方式的生成模型通常会比较模糊。原因是这样的模型太简单,无法使生成模型分布逼近真实数据的分布。可以采用神经网络(GAN)来解决这个问题。
GAN中,有生成器G,给定先验分布,希望得到生成分布。判别器D是一个函数,衡量与之间的差距。
定义V(G,D)
对于上述积分,取其最大值,我们希望对于给定x,积分里的项是最大的,也就是希望取一个最大的,最大化下面的式子。数据给定,G给定的情况下,和可看做常数,用a,b表示。可得如下式子:
可以看出,上述的值是一个0到1之间的值,当生成数据分布与真实数据分布非常接近的时候,应该输出的结果为。得到了给定G,求得使V(D)取得最大值的D,将D带回V(G,D)。
当且仅当等于时,可以取得全局最小值-log(4)。得到我们要求的最优生成器,而我们所要求的最优生成器,正是使得的分布等于
三、代码实现
1. 导包
from __future__ import print_function, division
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
from skimage.io import imsave
import matplotlib.pyplot as plt
import os
import numpy as np
2. 初始化
class GAN():
def __init__(self):
# 数据形状
self.img_rows = 28
self.img_cols = 28
self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
optimizer = Adam(0.0002, 0.5)
# 构建并且编译判别器
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
# 构建生成器
self.generator = self.build_generator()
# 给生成器输入噪声,得到图像
z = Input(shape=(100,))
img = self.generator(z)
# 控制判别器不被训练
self.discriminator.trainable = False
# 判别器判别的结果
validity = self.discriminator(img)
# 编译生成器和判别器组合在一起,训练生成器
self.combined = Model(z, validity)
self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
3. 构建生成器
def build_generator(self):
model = Sequential()
model.add(Dense(256, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
model.add(Reshape(self.img_shape))
model.summary()
noise = Input(shape=(self.latent_dim,))
img = model(noise)
return Model(noise, img)
4. 构建判别器
def build_discriminator(self):
model = Sequential()
model.add(Flatten(input_shape=self.img_shape))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1, activation='sigmoid'))
model.summary()
img = Input(shape=self.img_shape)
validity = model(img)
return Model(img, validity)
5. 训练
def train(self, epochs, batch_size=128, sample_interval=50):
# 加载数据
X_train, X_label = self.load_data("./MNIST_data/")
# 归一化到-1到1之间
X_train = X_train / 127.5 - 1.
X_train = np.expand_dims(X_train, axis=3)
# 设置空间
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
for epoch in range(epochs):
# ---------------------
# 训练判别器
# ---------------------
# 选择图像训练的批量
idx = np.random.randint(0, X_train.shape[0], batch_size)
imgs = X_train[idx]
# 生成新图像的批量
noise = np.random.normal(0, 1, (batch_size, 100))
gen_imgs = self.generator.predict(noise)
# 训练判别器
d_loss_real = self.discriminator.train_on_batch(imgs, valid)
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
# ---------------------
# 训练生成器
# ---------------------
noise = np.random.normal(0, 1, (batch_size, 100))
# 训练生成器,并且用判别器的结果作为标签
g_loss = self.combined.train_on_batch(noise, valid)
# Plot the progress
print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
# If at save interval => save generated image samples
if epoch % sample_interval == 0:
self.sample_images(epoch)
6. 加载数据
# 加载数据
def load_data(self, data_path):
'''
函数功能:加载数据
:param data_path: 数据存在路径
:return: train_data:训练数据,形状为(60000,28,28,1)
train_label:标签数据,形状为(60000,1)
'''
# 获取训练数据
f_data = open(os.path.join(data_path, 'train-images.idx3-ubyte'))
loaded_data = np.fromfile(file=f_data, dtype=np.uint8)
# 前16个字符是说明符,需要跳过
train_data = loaded_data[16:].reshape((-1, 28, 28)).astype(np.float)
# 获取标签数据
f_label = open(os.path.join(data_path, 'train-labels.idx1-ubyte'))
labeled_data = np.fromfile(file=f_label, dtype=np.uint8)
# 前8个字符需要跳过,
train_label = labeled_data[8:].reshape((-1)).astype(np.float)
return train_data, train_label
7. 展示结果
def sample_images(self, epoch):
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, 100))
gen_imgs = self.generator.predict(noise)
# Rescale images 0 - 1
gen_imgs = 0.5 * gen_imgs + 0.5
for m in range(1, 25):
fname = "image/%d_%d.png" % (epoch, m)
imsave(fname, gen_imgs[m, :, :, 0])
fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
axs[i,j].axis('off')
cnt += 1
fig.savefig("images/%d.png" % epoch)
plt.close()
8. 运行代码
if __name__ == '__main__':
gan = GAN()
gan.train(epochs=30000, batch_size=32, sample_interval=200)