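I recently read the CGAN paper (2014). It is short and to the point, and CGANs can be used for image inpainting and multimodal learning, which I find quite interesting. I will post my notes on the paper itself when I find the time.
Paper: Conditional Generative Adversarial Nets
Here is the full code first. Source: 【Keras-CGAN】MNIST / CIFAR-10
The way this code combines the noise z with the label, and the input image with the label, is different from the paper, so I suspect it does not quite reach the paper's results, although it still works well. The mechanism in the paper is more involved; for getting started, this version is enough to produce decent output. The networks here are plain multilayer perceptrons with no convolutional layers, so a DCGAN-style architecture would probably do better.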
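To make the difference concrete, here is a NumPy-only sketch of two ways to condition the generator input on a label. Variant (a) mirrors what the script does (multiply the noise by a per-class vector). Variant (b) appends a one-hot label to the noise, which is a common simplification and not the exact architecture from the paper. The random `embedding_table` and the example `labels` are stand-ins I made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, z_dim, n_classes = 4, 100, 10

z = rng.normal(size=(batch, z_dim))   # noise batch
labels = np.array([3, 1, 4, 1])       # hypothetical class labels

# (a) Multiplicative conditioning, as in the script: look up one
#     z_dim-sized vector per class and multiply it into the noise.
embedding_table = rng.normal(size=(n_classes, z_dim))  # stand-in for learned weights
g_input_mul = z * embedding_table[labels]
print(g_input_mul.shape)              # (4, 100)

# (b) Concatenation conditioning: append a one-hot label to the noise.
one_hot = np.eye(n_classes)[labels]
g_input_cat = np.concatenate([z, one_hot], axis=1)
print(g_input_cat.shape)              # (4, 110)
```

With (a) the input width stays at 100, so the generator's first Dense layer is unchanged; with (b) it grows to 110 and the label stays visible as separate input units.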
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Embedding, LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
# build_generator
model = Sequential()
model.add(Dense(256, input_dim=100))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod((28, 28, 1)), activation='tanh'))
model.add(Reshape((28, 28, 1)))
model.summary()
noise = Input(shape=(100,))  # 100-d noise; the trailing comma is needed so that shape is a tuple
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(10, 100)(label))  # 10 classes -> 100-d vector (same size as z)
model_input = multiply([noise, label_embedding])  # fuse label and noise by element-wise product
print(model_input.shape)
img = model(model_input)  # output shape (28, 28, 1)
generator = Model([noise, label], img)
# build_discriminator
model = Sequential()
model.add(Flatten(input_shape=(28,28,1)))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.4))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.4))
model.add(Dense(1, activation='sigmoid'))
model.summary()
img = Input(shape=(28, 28, 1))  # input image, shape (28, 28, 1)
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(10, np.prod((28, 28, 1)))(label))  # 10 classes -> 784-d vector
flat_img = Flatten()(img)
model_input = multiply([flat_img, label_embedding])  # fuse label and image (real or G(z)) by element-wise product
validity = model(model_input)  # real/fake score for the label-conditioned image
discriminator = Model([img, label], validity)
# compile models
optimizer = Adam(0.0002, 0.5)  # learning rate 0.0002, beta_1 0.5
# discriminator
discriminator.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])
# The combined model (stacked generator and discriminator)
noise = Input(shape=(100,))
label = Input(shape=(1,))
img = generator([noise, label])
# For the combined model we will only train the generator
discriminator.trainable = False
validity = discriminator([img, label])
# Trains the generator to fool the discriminator
combined = Model([noise, label], validity)
combined.summary()
combined.compile(loss='binary_crossentropy',
                 optimizer=optimizer)
import os

def sample_images(epoch):
    r, c = 2, 5
    noise = np.random.normal(0, 1, (r * c, 100))
    sampled_labels = np.arange(0, 10).reshape(-1, 1)  # one sample per digit 0-9
    gen_imgs = generator.predict([noise, sampled_labels])
    # Rescale images from [-1, 1] to [0, 1]
    gen_imgs = 0.5 * gen_imgs + 0.5
    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].set_title("Digit: %d" % sampled_labels[cnt, 0])
            axs[i, j].axis('off')
            cnt += 1
    os.makedirs("images", exist_ok=True)  # savefig fails if the directory is missing
    fig.savefig("images/mnist%d.png" % epoch)
    plt.close()
batch_size = 32
sample_interval = 200
# Load the dataset
(X_train, y_train), (_, _) = mnist.load_data()  # X_train: (60000, 28, 28)
# Rescale pixels from [0, 255] to [-1, 1]
X_train = X_train / 127.5 - 1.  # matches the range of the generator's tanh output
X_train = np.expand_dims(X_train, axis=3)  # (60000, 28, 28, 1)
# Adversarial ground truths
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
for epoch in range(50001):
    # ---------------------
    #  Train Discriminator
    # ---------------------
    # Select a random batch of images
    idx = np.random.randint(0, X_train.shape[0], batch_size)  # random indices in [0, 60000)
    #imgs = X_train[idx]
    imgs, labels = X_train[idx], y_train[idx]
    noise = np.random.normal(0, 1, (batch_size, 100))  # standard Gaussian noise
    # Generate a batch of new images
    gen_imgs = generator.predict([noise, labels])
    # Train the discriminator
    d_loss_real = discriminator.train_on_batch([imgs, labels], valid)  # real data, target 1
    d_loss_fake = discriminator.train_on_batch([gen_imgs, labels], fake)  # generated data, target 0
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
    # ---------------------
    #  Train Generator
    # ---------------------
    #noise = np.random.normal(0, 1, (batch_size, 100))
    sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)
    # Train the generator (to have the discriminator label samples as valid)
    g_loss = combined.train_on_batch([noise, sampled_labels], valid)
    # Plot the progress
    if epoch % 200 == 0:
        print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
    # If at save interval => save generated image samples
    if epoch % sample_interval == 0:
        sample_images(epoch)
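Two small numeric details of the script can be checked in isolation with NumPy: the pixel scaling used before training and when plotting, and the way the two discriminator results are averaged. This is only a sketch; the `[loss, acc]` values below are made up for illustration and do not come from a real run.

```python
import numpy as np

# Pixel scaling: [0, 255] -> [-1, 1] for training, then [-1, 1] -> [0, 1] for plotting.
pixels = np.array([0.0, 127.5, 255.0])
scaled = pixels / 127.5 - 1.0     # same expression as in the script: -1, 0, 1
restored = 0.5 * scaled + 0.5     # same expression as in sample_images: 0, 0.5, 1
print(scaled, restored)

# With metrics=['accuracy'], train_on_batch returns [loss, accuracy],
# so 0.5 * np.add(...) averages both entries element-wise.
d_loss_real = [0.60, 0.75]        # made-up [loss, acc] on the real batch
d_loss_fake = [0.80, 0.25]        # made-up [loss, acc] on the generated batch
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
print(d_loss)                     # roughly [0.7, 0.5]
```

This is why the progress line can index `d_loss[0]` as the loss and `d_loss[1]` as the accuracy.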
The walkthrough below focuses on the part that adds conditional information (the label) to the inputs of both the G and D networks.
1. G: the input noise z of G is combined with the label
model defines an MLP-based generator; after that:
noise = Input(shape=(100,))  # 100-d noise; the trailing comma is needed so that shape is a tuple
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(10, 100)(label))  # 10 classes -> 100-d vector (same size as z)
model_input = multiply([noise, label_embedding])  # fuse label and noise by element-wise product
print(model_input.shape)
img = model(model_input)  # output shape (28, 28, 1)
generator = Model([noise, label], img)
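The key point is understanding the Embedding layer; see the official documentation and related blog posts.
- label_embedding maps a label from a "vocabulary" of size 10 (10 classes in total) to a 100-dimensional vector, the same dimension as the noise.
- The Flatten layer flattens its input: the Embedding output has shape (batch, 1, 100), and Flatten turns it into (batch, 100).
- The Multiply layer computes the element-wise product of a list of input tensors, which is how the label and the noise z are combined. It takes a list of tensors that must all have the same shape and returns a single tensor of that shape. This is why the previous step maps the label to the same dimension as the noise.
- img = model(model_input) # output (28,28,1), generates an image
- Because of the label fusion above, the G network is defined as:
generator = Model([noise, label], img) # defines the final structure of G
- Keras has two kinds of models: the Sequential model and the functional model (Model).
- Model(inputs, outputs): here generator = Model([noise, label], img). The output of G still has the image shape (28,28,1).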
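The shape bookkeeping in the list above can be traced without Keras. The NumPy sketch below imitates the lookup, flatten and multiply steps; `table` is a random stand-in for the trainable weights of Embedding(10, 100), and the three labels are arbitrary examples.

```python
import numpy as np

rng = np.random.default_rng(0)
n_classes, z_dim, batch = 10, 100, 3

# Stand-in for the trainable weight matrix of Embedding(10, 100).
table = rng.normal(size=(n_classes, z_dim))

label = np.array([[7], [0], [7]])     # shape (batch, 1), like Input(shape=(1,))
embedded = table[label]               # embedding lookup -> (3, 1, 100)
flat = embedded.reshape(batch, -1)    # Flatten -> (3, 100)

noise = rng.normal(size=(batch, z_dim))
model_input = noise * flat            # element-wise multiply -> (3, 100)

print(embedded.shape, flat.shape, model_input.shape)
```

Rows 0 and 2 carry the same label, so they are multiplied by the same 100-d vector; only the noise differs between them.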
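2. D: the input img of D (a real or a generated image) is combined with the label
img = Input(shape=(28, 28, 1))  # input image, shape (28, 28, 1)
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(10, np.prod((28, 28, 1)))(label))  # 10 classes -> 784-d vector
flat_img = Flatten()(img)
model_input = multiply([flat_img, label_embedding])  # fuse label and image (real or G(z)) by element-wise product
validity = model(model_input)  # real/fake score for the label-conditioned image
discriminator = Model([img, label], validity) # defines the final structure of D
- np.prod() returns the product of all elements, so label_embedding maps a label from a "vocabulary" of size 10 (10 classes in total) to a vector of 28*28*1 = 784 dimensions, and Flatten makes it one-dimensional per sample. Multiply again combines the label with the input img.
- validity = model(model_input) uses the model defined for D to score the combination of label and image, judging whether it is real or fake.
- discriminator = Model([img, label], validity) # defines the final structure of D
- The rest (model compile, the combined model, and the training procedure including the loss functions) is the same as in DCGAN; the only difference is that the label has to be passed along with the inputs.
- sampled_labels = np.arange(0, 10).reshape(-1, 1)
- gen_imgs = generator.predict([noise, sampled_labels])
- These two lines in sample_images generate one image for each label from 0 to 9.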
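The discriminator side can be traced the same way. The NumPy sketch below follows the shapes from the image and label down to the vector that the Sequential D model receives. `table` is a random stand-in for the Embedding(10, 784) weights and `imgs` is random data in the tanh range, not real MNIST digits.

```python
import numpy as np

rng = np.random.default_rng(0)
img_shape = (28, 28, 1)
flat_dim = int(np.prod(img_shape))                 # 28 * 28 * 1 = 784

table = rng.normal(size=(10, flat_dim))            # stand-in for Embedding(10, 784) weights
imgs = rng.uniform(-1, 1, size=(10,) + img_shape)  # fake batch in the tanh range
sampled_labels = np.arange(0, 10).reshape(-1, 1)   # labels 0..9, shape (10, 1)

flat_img = imgs.reshape(len(imgs), -1)             # Flatten -> (10, 784)
label_embedding = table[sampled_labels].reshape(len(imgs), -1)  # lookup + Flatten -> (10, 784)
d_input = flat_img * label_embedding               # element-wise multiply -> (10, 784)

print(flat_dim, d_input.shape)                     # 784 (10, 784)
```

The label vector has to be 784 wide here for the same reason it is 100 wide in G: the element-wise product needs both operands to have the same shape.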
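Finally, here is a sample of the images generated after 50,000 iterations.
Because the G and D networks in this code are multilayer perceptrons (MLPs) without convolutional layers, the results are not great; switching to a DCGAN-style architecture would probably improve them considerably.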