GAN代码解析(tensorflow实现)_手写数字图片生成
基于py3.0支持中文名方法, 如果报错请把中文方法名,改为英文的
#coding:utf-8
# MNIST数据集
# MNIST数据集的官网是Yann LeCun’s website。在这里,我们提供了一份python源代码用于自动下载和安装这个数据集。
# 你可以下载这份代码,然后用下面的代码导入到你的项目里面,也可以直接复制粘贴到你的代码文件里面。
# import input_data
# mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# 下载下来的数据集被分成两部分:60000行的训练数据集(mnist.train)和10000行的测试数据集(mnist.test)。
#
# 每一张图片包含28像素X28像素。我们可以用一个数字数组来表示这张图片:
#
# 我们把这个数组展开成一个向量,长度是 28x28 = 784
#
# 因此,在MNIST训练数据集中,mnist.train.images 是一个形状为 [60000, 784] 的张量,
# 第一个维度数字用来索引图片,第二个维度数字用来索引每张图片中的像素点。在此张量里的每一个元素,都表示某张图片里的某个像素的强度值,值介于0和1之间。
#
# 相对应的MNIST数据集的标签是介于0到9的数字,用来描述给定图片里表示的数字.因此, mnist.train.labels 是一个 [60000, 10] 的数字矩阵。
# Dropout中隐层节点的忽略比例主要作用在隐层节点,是按照一定比例,随机地使部分隐层节点失效,并且该比例与最后通过模型平均来求得最后的预测值也有一定的关系。
# DAE中加噪比例作用于输入层,是按照一定比例,对每个网络的输入数据加入噪声,使得自动编码器通过学习获得真正的没有被噪声污染过的输入。这种加入噪声的思想,并不需要进行模型平均。
# tf.slice()介绍
# 函数:tf.slice(inputs, begin, size, name)
# 作用:从列表、数组、张量等对象中抽取一部分数据
# begin和size是两个多维列表,他们共同决定了要抽取的数据的开始和结束位置
# begin表示从inputs的哪几个维度上的哪个元素开始抽取
# size表示在inputs的各个维度上抽取的元素个数
# 若begin[]或size[]中出现-1,表示抽取对应维度上的所有元素
# import tensorflow as tf
# import numpy as np
# x=[[1,2,3],[4,5,6]]
# with tf.Session() as sess:
# begin = [0,1] # 从x[0,1],即元素2开始抽取
# size = [2,1] # 从x[0,1]开始,对x的第一个维度(行)抽取2个元素,在对x的第二个维度(列)抽取1个元素
# print sess.run(tf.slice(x,begin,size)) # 输出[[2 5]]
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #导入图片数据
import numpy as np
from skimage.io import imsave #读图片
import os
import shutil #图片处理
import sys
img_height = 28
img_width = 28
img_size = img_height * img_width #图片像素28X28 ,做全链接拉伸784
to_train = True #训练开关
to_restore = False #保存模型开关
output_path = "./output" #保存模型路径
# 总迭代次数500
max_epoch = 500
#隐层神经元的个数
h1_size = 150 #隐层神经元第一层神经元的个数
h2_size = 300 #隐层神经元第2层神经元的个数
z_size = 100 #输入的噪音点(输入也为100)
batch_size = 256 #batch_size 一次256张图片。 判别模型有512张(真假各一半)
# generate (model 1) 这里用的全连接
def 生成模型(z_prior):#build_generator #初始化W,b参数 ,刚开始输入层是100,hide1 150,hide2 是300
w1 = tf.Variable(tf.truncated_normal(shape=[z_size, h1_size], stddev=0.1), name="g_w1", dtype=tf.float32) #truncated_normal随机生成函数产生正太分布的W,这是一个截断的产生正太分布的函数,就是说产生正太分布的值如果与均值的差值大于两倍的标准差,那就重新生成。
b1 = tf.Variable(tf.zeros([h1_size]), name="g_b1", dtype=tf.float32)
h1 = tf.nn.relu(tf.matmul(z_prior, w1) + b1)
w2 = tf.Variable(tf.truncated_normal(shape=[h1_size, h2_size], stddev=0.1), name="g_w2", dtype=tf.float32)
b2 = tf.Variable(tf.zeros([h2_size]), name="g_b2", dtype=tf.float32)
h2 = tf.nn.relu(tf.matmul(h1, w2) + b2)
w3 = tf.Variable(tf.truncated_normal(shape=[h2_size, img_size], stddev=0.1), name="g_w3", dtype=tf.float32)
b3 = tf.Variable(tf.zeros([img_size]), name="g_b3", dtype=tf.float32) #b3的大小相当于图片拉直后的像素点个数784
h3 = tf.matmul(h2, w3) + b3
x_generate = tf.nn.tanh(h3) #生成784维的数
g_params = [w1, b1, w2, b2, w3, b3]
return x_generate, g_params
# discriminator (model 2)
# x_data是真是值
# keep_prob 隐层节点的忽略比例 ,dropout 比例
def 判别模型(x_data, x_generated, keep_prob):#build_discriminator
# tf.concat
x_in = tf.concat([x_data, x_generated], 0)#两倍于生成模型数量的图片,因为要参杂真是图片
w1 = tf.Variable(tf.truncated_normal([img_size, h2_size], stddev=0.1), name="d_w1", dtype=tf.float32) #784X300
b1 = tf.Variable(tf.zeros([h2_size]), name="d_b1", dtype=tf.float32)
h1 = tf.nn.dropout(tf.nn.relu(tf.matmul(x_in, w1) + b1), keep_prob)#W1 784X300
w2 = tf.Variable(tf.truncated_normal([h2_size, h1_size], stddev=0.1), name="d_w2", dtype=tf.float32)#300X150
b2 = tf.Variable(tf.zeros([h1_size]), name="d_b2", dtype=tf.float32)
h2 = tf.nn.dropout(tf.nn.relu(tf.matmul(h1, w2) + b2), keep_prob)
w3 = tf.Variable(tf.truncated_normal([h1_size, 1], stddev=0.1), name="d_w3", dtype=tf.float32)#150
b3 = tf.Variable(tf.zeros([1]), name="d_b3", dtype=tf.float32)
h3 = tf.matmul(h2, w3) + b3#是输出一个数判别生成真是图片的概率
y_data = tf.nn.sigmoid(tf.slice(h3, [0, 0], [batch_size, -1], name=None))#切片,代表的是真实的数据y_data ,[batch_size, -1]256行,-1代表所有列,生成的256
y_generated = tf.nn.sigmoid(tf.slice(h3, [batch_size, 0], [-1, -1], name=None))#原图形和生成图形一一对应,[-1, -1]取剩下的所有行所有列
d_params = [w1, b1, w2, b2, w3, b3]
return y_data, y_generated, d_params
# 保存图片进度及保存图片算法
#show_result(x_gen_val, "output_random/random_sample{0}.jpg".format(i))
# grid_pad=5 没有填充的地方用5去填
def 展示结果保存(batch_res, fname, grid_size=(8, 8), grid_pad=5):#show_result
#数字转化为图片
batch_res = 0.5 * batch_res.reshape((batch_res.shape[0], img_height, img_width)) + 0.5 #除以2 reshape 784维的常量 变为28x28,分别作为图片像素的宽高,保存下来.如下,做图像还原时要重新加上均值0.5
img_h, img_w = batch_res.shape[1], batch_res.shape[2]
grid_h = img_h * grid_size[0] + grid_pad * (grid_size[0] - 1)
grid_w = img_w * grid_size[1] + grid_pad * (grid_size[1] - 1)
img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)
for i, res in enumerate(batch_res):
if i >= grid_size[0] * grid_size[1]:
break
img = (res) * 255
img = img.astype(np.uint8)
row = (i // grid_size[0]) * (img_h + grid_pad)
col = (i % grid_size[1]) * (img_w + grid_pad)
img_grid[row:row + img_h, col:col + img_w] = img
imsave(fname, img_grid) #保存图
def 开始训练():
# load data(mnist手写数据集)
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
x_data = tf.placeholder(tf.float32, [batch_size, img_size], name="x_data") #真实值256X784 ,一次传的数据
z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior") #输入
keep_prob = tf.placeholder(tf.float32, name="keep_prob")#dropout 比例0.7
global_step = tf.Variable(0, name="global_step", trainable=False)#总共迭代多少步?反向更新多少次
# 创建生成模型
x_generated, g_params = 生成模型(z_prior)
# 创建判别模型
y_data, y_generated, d_params = 判别模型(x_data, x_generated, keep_prob)
# 损失函数的设置
d_loss = - (tf.log(y_data) + tf.log(1 - y_generated)) #整个数据的交叉熵
g_loss = - tf.log(y_generated)#生成器的损失函数,计算的就是生产数据的交叉熵
optimizer = tf.train.AdamOptimizer(0.0001)
# 两个模型的优化函数
d_trainer = optimizer.minimize(d_loss, var_list=d_params)
g_trainer = optimizer.minimize(g_loss, var_list=g_params)
init = tf.initialize_all_variables()
# init = tf.global_variables_initializer()
saver = tf.train.Saver()
# 启动默认图
sess = tf.Session()
# 初始化图
sess.run(init)
#tensorflow模型持久化
if to_restore:
chkpt_fname = tf.train.latest_checkpoint(output_path)
print(chkpt_fname)
saver.restore(sess, chkpt_fname)
else:
if os.path.exists(output_path):
shutil.rmtree(output_path)#如果存在则删除
os.mkdir(output_path)#如果不存在则创建
z_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)#接着从0-1均匀分布中抽取了z(至于为什么用这个分布,可以去查看一个概率论,几乎所有重要的概率分布都可以从均匀分布Uniform(0,1)中生成出来)
steps = 60000 / batch_size #训练集中图片的数量60000,
for i in range(sess.run(global_step), max_epoch):
for j in np.arange(steps):
# for j in range(steps):
print("epoch:%s, iter:%s" % (i, j))
# 每一步迭代,我们都会加载256个训练样本,然后执行一次train_step
x_value, _ = mnist.train.next_batch(batch_size)
x_value = 2 * x_value.astype(np.float32) - 1#python是从0开始的
z_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
# 执行生成
sess.run(d_trainer,
feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})
# 执行判别
if j % 1 == 0: #每个bitch_size打印一次
sess.run(g_trainer,
feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})
x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_sample_val})
展示结果保存(x_gen_val, "./output_sample/sample{0}.jpg".format(i))
#以下三句可以省略
z_random_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_random_sample_val})
展示结果保存(x_gen_val, "./output_random/random_sample{0}.jpg".format(i))
sess.run(tf.assign(global_step, i + 1)) #tf.assign(A, new_number): 这个函数的功能主要是把A的值变为new_number,赋值作用
saver.save(sess, os.path.join(output_path, "model"), global_step=global_step)
# def test():
# z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")
# x_generated, _ = build_generator(z_prior)
# chkpt_fname = tf.train.latest_checkpoint(output_path)
#
# init = tf.initialize_all_variables()
# sess = tf.Session()
# saver = tf.train.Saver()
# sess.run(init)
# saver.restore(sess, chkpt_fname)
# z_test_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
# x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_test_value})
# show_result(x_gen_val, "output/test_result.jpg")
if __name__ == '__main__':
# if to_train:
# train()
# else:
# test()
开始训练()
random_sample314.jpg