个人博客:http://www.chenjianqu.com/
原文链接:http://www.chenjianqu.com/show-54.html
生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。论文《Generative Adversarial Nets》首次提出GAN。
GAN的思想
GAN由生成器G和判别器D组成。生成器G根据输入先验分布的随机向量(一般使用随机分布,论文里用的是高斯分布)得到符合数据集数据分布。判别器D判别输入数据来源于G还是真实数据。框架如下:
GAN的训练过程:刚开始G和D里面的参数随机初始化,第一步使用真实图片训练D,则D能轻易判断G生成的图片和真实图片。接着训练G,使得G生成的图片更加逼真,直到D无法判断G生成的图片和真实图片。接着训练D,使D能轻易判断真实图片。。。以此类推,最终G生成的图片和真实图片很相似。就好像论文里的例子,G就像货币的伪造者,D就像警察,G造假币,D识别假币,两者互相对抗。G造的假币越来越逼真,D识别的手段越来越高超。最终G生成的东西跟真的差不多。过程示意:
目标函数
从本质上看,GAN的训练目标就是使G恢复训练数据的数据分布。数据分布可以理解为数据的概率函数。
用更数学的表示,G将先验分布(比如高斯分布)P_prior(z)中的z通过G映射到x,即x的分布为P_G(x,theta),theta即G的参数。将生成器分布P_G(x,theta)和真实数据分布P_data(x)对比即可得到损失loss,如下图:
判别器D评估P_G(x)和已知数据分布P_data(x)的差异,即判别输入的x是来自真实分布还是生成器。
根据GAN的思想,其优化过程可以表示为以下公式:
将上式拆分,得到D和G的目标函数。
优化D,D的目标是将G(z)判断为真的概率D(G(z))尽可能小,即1-D(G(z))尽可能大,且将真实数据x判断为真的概率D(x)尽可能大。因此得公式如下:
优化G,G的目标是令D判断为真的概率尽可能大,即D(G(z))尽可能大。因此得公式如下:
实际训练G的时候,早期要求V的初始斜率大,因此需要替换V:
实际的数据是离散的,因此计算分布的期望是通过采样计算得到的。论文里提出,迭代的优化k步D和1步G。这可以让D保持在最优解附近,可得迭代优化参数的算法:
理论证明
以下的理论分析将证明GAN的目标函数有一个最优解p_g=p_data。
对于目标函数:
上式是求两个期望的相加,等价于:
我们想要在找到最优的D*,使得V(G,D)最大:
求V(G,D)等价与求以下公式最大:
上式中, P_data(x)是一个常量,表示x对应的概率分布中的值,这里设为a,P_G(x)也是如此,设为b。因此可以对上式进行求导,即可得到D*,过程如下:
代入D目标函数,得:
当且仅当p_g=p_data,C(G)取得最大值,此时C(G)=-log4,如下:
将D*代入V(G,D)的积分表达式,得:
分子分母同时除以2。再将1/2提出来,且和等于1,则:
上式的KL是KL散度。KL散度:相对熵(relative entropy),又被称为Kullback-Leibler散度(Kullback-Leibler divergence)或信息散度(information divergence),是两个概率分布(probability distribution)间差异的非对称性度量 [1] 。在在信息理论中,相对熵等价于两个概率分布的信息熵(Shannon entropy)的差值。离散和连续随机变量的公式如下:
继续得到:
JSD是Jesen-Shannon散度。由于两个分布之间的Jensen-Shannon散度总是非负的,只有当它们相等时才为零,所以我们已经证明了C∗=−log(4)是C(G)的全局最小值,而唯一的解决方案是p_g = p_data,即,生成模型完美地复制了数据生成过程。
代码实现
论文里生成器使用全连接网络,使用relu和sigmoid激活函数;判别器也是全连接网络,使用maxout激活函数,同时应用了dropout。生成器的输入是符合高斯分布的随机向量,theta值根据交叉验证得到。在MNIST、TFD和CIFAR-10上面测试。
我这里的代码也是参照了网上开源的代码,判别器和生成器均使用relu和sigmoid激活函数。经过我的测试,发现每轮迭代的时候,生成器应该要比判别器训练更多,否则会发散,这跟论文里的描述相反。
网络的计算图如下:
代码如下:
定义参数
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import name_scope as namespace
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
# 训练参数
num_steps = 20000
batch_size = 128
learning_rate = 0.0002
# 网络参数
image_dim = 784 # 28*28
gen_hidden_dim = 256
disc_hidden_dim = 256
noise_dim = 100 # Noise data points
k=1
# 保存隐藏层的权重和偏置,用于变量共享
with namespace('var'):
weights = {
'gen_hidden1': tf.Variable(tf.truncated_normal([noise_dim, gen_hidden_dim],stddev=0.1)),
'gen_out': tf.Variable(tf.truncated_normal([gen_hidden_dim, image_dim],stddev=0.1)),
'disc_hidden1': tf.Variable(tf.truncated_normal([image_dim, disc_hidden_dim],stddev=0.1)),
'disc_out': tf.Variable(tf.truncated_normal([disc_hidden_dim, 1],stddev=0.1)),
}
biases = {
'gen_hidden1': tf.Variable(tf.zeros([gen_hidden_dim])),
'gen_out': tf.Variable(tf.zeros([image_dim])),
'disc_hidden1': tf.Variable(tf.zeros([disc_hidden_dim])),
'disc_out': tf.Variable(tf.zeros([1])),
}
定义网络和优化器
# 生成网络
def generator(x):
with namespace('gen_hidden1'):
hidden_layer = tf.matmul(x, weights['gen_hidden1'])
hidden_layer = tf.add(hidden_layer, biases['gen_hidden1'])
hidden_layer = tf.nn.relu(hidden_layer)
with namespace('gen_out'):
out_layer = tf.matmul(hidden_layer, weights['gen_out'])
out_layer = tf.add(out_layer, biases['gen_out'])
out_layer = tf.nn.sigmoid(out_layer)
return out_layer
# 判别网络
def discriminator(x):
with namespace('disc_hidden1'):
hidden_layer = tf.matmul(x, weights['disc_hidden1'])
hidden_layer = tf.add(hidden_layer, biases['disc_hidden1'])
hidden_layer = tf.nn.relu(hidden_layer)
with namespace('disc_output'):
out_layer = tf.matmul(hidden_layer, weights['disc_out'])
out_layer = tf.add(out_layer, biases['disc_out'])
out_layer = tf.nn.sigmoid(out_layer)
return out_layer
# 网络输入
gen_input = tf.placeholder(tf.float32, shape=[None, noise_dim], name='input_noise')
disc_input = tf.placeholder(tf.float32, shape=[None, image_dim], name='disc_input')
# 创建生成网络
with namespace('generator'):
gen_sample = generator(gen_input)
# 创建两个判别网络 (一个来自噪声输入, 一个来自生成的样本)
with namespace('discriminator'):
with namespace('discriminator_real'):
disc_real = discriminator(disc_input)
with namespace('discriminator_fake'):
disc_fake = discriminator(gen_sample)
with namespace('loss'):
# 定义损失函数
gen_loss = -tf.reduce_mean(tf.log(disc_fake))
disc_loss = -tf.reduce_mean(tf.log(disc_real) + tf.log(1. - disc_fake))
#将变量的损失值写入Loss
tf.summary.scalar('gen_loss', gen_loss)
tf.summary.scalar('disc_loss', disc_loss)
merged_summary = tf.summary.merge_all()
with namespace('train'):
# 定义优化器
optimizer_gen = tf.train.AdamOptimizer(learning_rate=learning_rate)
optimizer_disc = tf.train.AdamOptimizer(learning_rate=learning_rate)
# 训练每个优化器的变量
# 生成网络变量
gen_vars = [weights['gen_hidden1'], weights['gen_out'],biases['gen_hidden1'], biases['gen_out']]
# 判别网络变量
disc_vars = [weights['disc_hidden1'], weights['disc_out'],biases['disc_hidden1'], biases['disc_out']]
# 最小损失函数
train_gen = optimizer_gen.minimize(gen_loss, var_list=gen_vars)
train_disc = optimizer_disc.minimize(disc_loss, var_list=disc_vars)
# 初始化变量
init = tf.global_variables_initializer()
训练网络
def getData(batch_size=128):
batch_x, _ = mnist.train.next_batch(batch_size)# 准备数据
z = np.random.uniform(-1., 1., size=[batch_size, noise_dim])# 产生噪声给生成网络
return batch_x,z
# 开始训练
with tf.Session() as sess:
sess.run(init)
writer=tf.summary.FileWriter('D:/Jupyter/GAN/mnist_gan_train_log/log',sess.graph)
saver=tf.train.Saver()
for i in range(1, num_steps+1):
for j in range(k):
x,z=getData()
_,dl = sess.run([train_disc, disc_loss], feed_dict={disc_input: x, gen_input: z})
x,z=getData()
_,gl = sess.run([train_gen, gen_loss], feed_dict={disc_input: x, gen_input: z})
x,z=getData()
_,gl = sess.run([train_gen, gen_loss], feed_dict={disc_input: x, gen_input: z})
x,z=getData()
_,gl = sess.run([train_gen, gen_loss], feed_dict={disc_input: x, gen_input: z})
x,z=getData()
summary,g = sess.run([merged_summary,gen_sample], feed_dict={disc_input:x,gen_input:z})
writer.add_summary(summary,i)#写summary和i到文件
if i % 1000 == 0 or i == 1:
print('Step %i: Generator Loss: %f, Discriminator Loss: %f' % (i, gl, dl))
# 使用生成器网络从噪声生成图像
f, a = plt.subplots(4, 10, figsize=(10, 4))
for i in range(10):
# 噪声输入.
z = np.random.uniform(-1., 1., size=[4, noise_dim])
g = sess.run([gen_sample], feed_dict={gen_input: z})
g = np.reshape(g, newshape=(4, 28, 28, 1))
# 将原来黑底白字转换成白底黑字,更好的显示
g = -1 * (g - 1)
for j in range(4):
# 从噪音中生成图像。 扩展到3个通道,用于matplotlib
img = np.reshape(np.repeat(g[j][:, :, np.newaxis], 3, axis=2),newshape=(28, 28, 3))
a[j][i].imshow(img)
plt.savefig('test.png')#保存图片
#f.show()
#plt.draw()
#plt.waitforbuttonpress()
生成的结果:
训练过程的损失值:
从损失曲线可以看到,GAN的训练过程是不稳定的。
参考文献
[1]Ian J. Goodfellow,etc.Generative Adversarial Nets.2014.arXiv:1406.2661v1
[2]小白的成长. GAN之V(D,G)函数. https://blog.csdn.net/qq_42413820/article/details/80673857. 2018-06-13