GAN的基本总结和小型demo
关于GANS(Generative Adversarial Networks)
属于生成模型(generative models)
属于无监督学习(unsupervised learning)
在不给定目的值的情况下,学习所给数据的底层结构。
目前可生成最清晰的图像。
易于训练(不需要统计推断),只需要反向推断就能够获得梯度。
由于训练动态不稳定,难以优化。
基本不能做统计推断。
属于直接隐式密度模型,没有明确定义概率分布函数模型。
Generator和Discriminator
Discriminator
最大化被分类为属于真数据集的真数据输入
最小化被分类为属于真数据集的假数据输入
Generator
最大化被分类为属于真数据集的假数据输入
这意味着用于此网络的损耗/误差函数(loss/error函数)要最大化
经过许多步的训练,Generator和Discriminator都有足够的能力,均不能再进行改进,此时Generator就能生成真实的合成数据,而Discriminator已经无法区分。
训练GAN的基本步骤
1.采样噪声集和真实数据集,每个数据集具有大小m。
2.在这个数据上训练鉴别器。
3.采样具有大小m的不同噪声子集。
4.根据这个数据训练生成器。
5.从步骤1重复。
GAN的小型demo
1.导入相关库
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.transforms as t
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dataset
import numpy as np
# 绘制图像库
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
2.设置plt属性
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 设置大小
plt.rcParams['image.interpolation'] = 'nearest' # 设置插值模式
plt.rcParams['image.cmap'] = 'gray' # 设置颜色
3.图片显示
def show_images(images):
images = np.reshape(images, [images.shape[0], -1]) # -1代表自动计算
sqrtn = int(np.ceil(np.sqrt(images.shape[0]))) # np.ceil取整
sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
gs = gridspec.GridSpec(sqrtn, sqrtn)
gs.update(wspace=0.05, hspace=0.05)
for i, img in enumerate(images):
ax = plt.subplot(gs[i])
plt.axis('off') # 去掉坐标轴
ax.set_xticklabels([]) # 设置x标记为空
ax.set_yticklabels([]) # 设置y标记为空
ax.set_aspect('equal')
plt.imshow(img.reshape([sqrtimg, sqrtimg]))
return
4.采样函数
# 采样函数为自己定义的序列采样(即按顺序采样)
class Sampler(sampler.Sampler):
def __init__(self, num_samples, start=0):
self.num_samples = num_samples
self.start = start
def __iter__(self):
return iter(range(self.start, self.start + self.num_samples))
def __len__(self):
return self.num_samples
5.训练集和测试集的设置
NUM_TRAIN = 60000 # 训练集数量
NUM_VAL = 10000 # 测试集数量
NOISE_DIM = 96 # 噪声维度
batch_size = 128 # 批尺寸
mnist_train = dataset.MNIST('./cs231n/datasets/MNIST_data', train=True, download=True, transform=t.ToTensor())
loader_train = DataLoader(mnist_train, batch_size=batch_size, sampler=Sampler(NUM_TRAIN, 0))
# 从0位置开始采样NUM_TRAIN个数
mnist_val = dataset.MNIST('./cs231n/datasets/MNIST_data', train=True, download=True, transform=t.ToTensor())
loader_val = DataLoader(mnist_val, batch_size=batch_size, sampler=Sampler(NUM_VAL, NUM_TRAIN))
# 从NUM_TRAIN位置开始采样NUM_VAL个数
imgs = loader_train.__iter__().next()[0].view(batch_size, 784).numpy().squeeze()
show_images(imgs) # 显示训练集图片
6.均匀噪声函数
def sample_noise(batch_size, dim):
"""
- 产生一个从-1 ~ 1的均匀噪声函数,形状为 [batch_size, dim].
参数:
- batch_size: 整型 提供生成的batch_size
- dim: 整型 提供生成维度
"""
temp = torch.rand(batch_size, dim) + torch.rand(batch_size, dim)*(-1)
return temp
7.平铺函数
# 平铺函数
class Flatten(nn.Module):
def forward(self, x):
n, c, h, w = x.size() # 读取为n,c,h,w
return x.view(n, -1) # 每张图片把c*h*w的值传入单向量用于后期处理
8.判别器
# 判别器 判断generator产生的图像是否为假,同时判断正确的图像是否为真
def discriminator():
model = nn.Sequential(
Flatten(),
nn.Linear(784, 256),
nn.LeakyReLU(0.01, inplace=True),
nn.Linear(256, 256),
nn.LeakyReLU(0.01, inplace=True),
nn.Linear(256, 1)
)
return model
9.生成器
# 生成器
def generator(noise_dim=NOISE_DIM):
model = nn.Sequential(
nn.Linear(noise_dim, 1024),
nn.ReLU(inplace=True),
nn.Linear(1024, 1024),
nn.ReLU(inplace=True),
nn.Linear(1024, 784),
nn.Tanh(),
)
return model
10.损失函数
# GAN中指出的最大化最小化损失的算法
Bce_loss = nn.BCEWithLogitsLoss()
def discriminator_loss(logits_real, logits_fake):
loss = None
# Batch size.
n = logits_real.size()
# 目标label,全部设置为1意味着判别器需要做到的是将正确的全识别为正确,错误的全识别为错误
true_labels = Variable(torch.ones(n))
real_image_loss = Bce_loss(logits_real, true_labels) # 识别正确的为正确
fake_image_loss = Bce_loss(logits_fake, 1 - true_labels) # 识别错误的为错误
loss = real_image_loss + fake_image_loss
return loss
def generator_loss(logits_fake):
n = logits_fake.size()
# 生成器的作用是将所有“假”的向真的(1)靠拢
true_labels = Variable(torch.ones(n))
# 计算生成器损失
loss = Bce_loss(logits_fake, true_labels)
return loss
11.Adam优化器
def get_optimizer(model):
"""
为模型构建并返回一个Adam优化器
learning rate 1e-3,
beta1=0.5, and beta2=0.999.
"""
# params(iterable):可用于迭代优化的参数或者定义参数组的dicts。
# lr (float, optional) :学习率(默认: 1e-3)
# betas (Tuple[float, float], optional):用于计算梯度的平均和平方的系数(默认: (0.9, 0.999))
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.5, 0.999))
return optimizer
12.GAN函数
def run_a_gan(D, G, D_solver, G_solver, discriminator_loss, generator_loss, show_every=250,
batch_size=128, noise_size=96, num_epochs=10):
"""
训练GAN
- D, G: 分别为判别器和生成器
- D_solver, G_solver: D,G的优化器
- discriminator_loss, generator_loss: 计算D,G的损失
- show_every: 设置每show_every次显示样本
- batch_size: 每次训练在训练集中取batch_size个样本训练
- noise_size: 输入进生成器的噪声维度
- num_epochs: 训练迭代次数
"""
iter_count = 0
for epoch in range(num_epochs):
for x, _ in loader_train:
if len(x) != batch_size:
continue
D_solver.zero_grad()
real_data = Variable(x)
logits_real = D(2 * (real_data - 0.5))
g_fake_seed = Variable(sample_noise(batch_size, noise_size))
fake_images = G(g_fake_seed).detach()
logits_fake = D(fake_images.view(batch_size, 1, 28, 28))
d_total_error = discriminator_loss(logits_real, logits_fake)
d_total_error.backward()
D_solver.step()
G_solver.zero_grad()
g_fake_seed = Variable(sample_noise(batch_size, noise_size))
fake_images = G(g_fake_seed)
gen_logits_fake = D(fake_images.view(batch_size, 1, 28, 28))
g_error = generator_loss(gen_logits_fake)
g_error.backward()
G_solver.step()
print(iter_count)
if iter_count % show_every == 0:
print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error, g_error))
imgs_numpy = fake_images.data.cpu().numpy()
show_images(imgs_numpy[0:16])
plt.show()
print()
iter_count += 1
print("Completed!")
imgs_numpy = fake_images.data.cpu().numpy()
show_images(imgs_numpy[0:16])
plt.show()
print()
13.
# 创建判别器
D = discriminator()
# 创建生成器
G = generator()
# 创建D,G的优化器
D_solver = get_optimizer(D)
G_solver = get_optimizer(G)
# 运行GAN
run_a_gan(D, G, D_solver, G_solver, discriminator_loss, generator_loss)
最终结果为:
效果很差,后续我还需要对这个初步的GAN进行完善。