Adversarial Neural Network Learning (5): Generating Clothing Images of Varying Width and Height with infoGAN (TensorFlow Implementation)

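I. Background

I was busy for a while and had not continued my GAN experiments for some time. I recently found the time to finish infoGAN. In my view, infoGAN is a further improvement on GAN: because a GAN takes randomly generated noise as input, the images it produces are also random, whereas infoGAN aims to generate images with specified features. To do so, infoGAN adds constraints to the GAN's random input, which is its main improvement. infoGAN was proposed by Xi Chen et al. in June 2016. This experiment uses infoGAN to generate clothing images of varying width and height.

This experiment uses the fashion-mnist dataset as an example and implements infoGAN with as little code as possible.

[1] Paper link: https://arxiv.org/abs/1606.03657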

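II. How infoGAN Works

There are already many introductions to the principles of infoGAN online, so I will not go into much detail here. I recommend one article that explains the principles clearly:

[2] Introduction to InfoGAN

According to the authors' own description, the main contribution of infoGAN is the introduction of mutual information: by maximizing the mutual information between a subset of the GAN's noise variables and the observations, the learned representation becomes interpretable. The authors applied the method to the MNIST, CelebA, and SVHN datasets, and report that it discovers meaningful, disentangled representations on all of them.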

In this paper, we present a simple modification to the generative adversarial network objective that encourages it to learn interpretable and meaningful representations. We do so by maximizing the mutual information between a fixed small subset of the GAN’s noise variables and the observations, which turns out to be relatively straightforward. Despite its simplicity, we found our method to be surprisingly effective: it was able to discover highly semantic and meaningful hidden representations on a number of image datasets: digits (MNIST), faces (CelebA), and house numbers (SVHN). The quality of our unsupervised disentangled representation matches previous works that made use of supervised label information [5–9]. These results suggest that generative modelling augmented with a mutual information cost could be a fruitful approach for learning disentangled representations.
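To make the term "mutual information" a little more concrete, here is a minimal NumPy sketch that computes it for a small discrete joint distribution. This is only an illustration of the quantity itself, not part of the infoGAN code; the function name `mutual_information` and the two toy tables are my own assumptions.

```python
import numpy as np

def mutual_information(joint):
    """Return I(X;Y) in nats for a joint probability table p(x, y)."""
    joint = np.asarray(joint, dtype=np.float64)
    px = joint.sum(axis=1, keepdims=True)   # marginal p(x), shape (n, 1)
    py = joint.sum(axis=0, keepdims=True)   # marginal p(y), shape (1, m)
    mask = joint > 0                        # skip empty cells (0 * log 0 = 0)
    ratio = joint[mask] / (px @ py)[mask]   # p(x, y) / (p(x) p(y))
    return float(np.sum(joint[mask] * np.log(ratio)))

# Independent variables carry no information about each other.
independent = np.full((2, 2), 0.25)
# Perfectly coupled variables: knowing x determines y.
dependent = np.array([[0.5, 0.0], [0.0, 0.5]])

print(mutual_information(independent))  # 0.0
print(mutual_information(dependent))    # ln 2, about 0.693
```

The second case is the situation infoGAN pushes toward: the code c should tell us as much as possible about the generated image.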

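Put more simply, "the GAN model places no restrictions on how the generator uses the noise z, so z is used in a highly entangled way and no single dimension of z clearly represents a feature. During data generation we therefore cannot tell which noise z produces the digit 1 and which produces the digit 3. We know nothing about this, which to some extent limits how we can use GANs [2]." The infoGAN improvement targets exactly this problem: "Info stands for mutual information, which measures how strongly the generated data x is related to the latent code c. To tie x and c closely together, we need to maximize the mutual information, so the value function of the original GAN is modified slightly, which amounts to adding a mutual-information regularization term. λ is a hyperparameter, and the later experiments settled on a value of 1. [2]"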
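As a rough sketch of what "adding a latent code c" means in practice, the snippet below builds a generator input from noise z plus a code made of a 10-way one-hot label and 2 continuous values. The dimensions (62, 10, 2) and the uniform(-1, 1) sampling match the code later in this post; the function name `sample_generator_input` is hypothetical, and the labels here are drawn at random (the unsupervised case).

```python
import numpy as np

def sample_generator_input(batch_size, z_dim=62, n_classes=10, n_cont=2, rng=None):
    """Sample noise z and latent code c, and return them concatenated."""
    rng = np.random.default_rng() if rng is None else rng
    z = rng.uniform(-1, 1, size=(batch_size, z_dim))         # unstructured noise
    labels = rng.integers(0, n_classes, size=batch_size)     # discrete code
    c_disc = np.eye(n_classes)[labels]                       # one-hot, (batch, 10)
    c_cont = rng.uniform(-1, 1, size=(batch_size, n_cont))   # continuous code
    code = np.concatenate([c_disc, c_cont], axis=1)          # (batch, 12)
    return z, code, np.concatenate([z, code], axis=1)        # (batch, 74)

z, code, g_in = sample_generator_input(4, rng=np.random.default_rng(0))
print(z.shape, code.shape, g_in.shape)
```

Only the 12 code dimensions are tied to the images through the mutual-information term; the 62 noise dimensions remain unconstrained.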

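As the authors note later in the paper, infoGAN improves on a weakness of GAN while still using DCGAN as its base model, so the implementation of infoGAN can follow that of DCGAN. After a series of derivations, the authors arrive at the following final result: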

Hence, InfoGAN is defined as the following minimax game with a variational regularization of mutual information and a hyperparameter λ: 

\min_{G,Q}\max_{D} V_{\text{InfoGAN}}(D,G,Q) = V(D,G) - \lambda L_{I}(G,Q)
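For reference, the two terms on the right-hand side can be written out as in the paper [1]. V(D,G) is the standard GAN value function (with the generator now taking both z and c), and L_I is a variational lower bound on the mutual information between the code c and the generated sample, where Q is an auxiliary distribution that approximates the posterior P(c|x):

```latex
V(D,G) = \mathbb{E}_{x \sim P_{data}}[\log D(x)]
       + \mathbb{E}_{z \sim P_{z},\, c \sim P(c)}[\log(1 - D(G(z,c)))]

L_{I}(G,Q) = \mathbb{E}_{c \sim P(c),\, x \sim G(z,c)}[\log Q(c \mid x)] + H(c)
           \le I(c;\, G(z,c))
```

Because H(c) is constant for a fixed code distribution, maximizing L_I in practice means training Q to recover c from the generated image.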

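It is fine if you do not fully understand the formula above; it is enough to know which aspect of GAN infoGAN improves. Many implementations of infoGAN are available online. Here are a few good ones: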

[3]https://github.com/openai/InfoGAN

[4]https://github.com/AndyHsiao26/InfoGAN-tensorflow

[5]https://github.com/hwalsuklee/tensorflow-generative-model-collections

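The goal of this experiment is to implement infoGAN with the least and simplest code possible. It mainly follows the implementation in [5], with a few small changes to the original code. This time I kept the class from the original code, whereas my earlier GAN implementations were written as functions wherever possible.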

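III. Implementing infoGAN

1. Data preparation

The data for this experiment is the fashion-mnist dataset. As the name suggests, it has the same format as the mnist dataset, except that it contains 10 classes of clothing; the images are still 28*28 grayscale images.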
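For readers who want to know what the 10 classes are, the small sketch below maps a label index (or a one-hot label row) to its Fashion-MNIST class name. The class order follows the dataset's README; the helper name `label_name` is my own and is not part of the code in this post.

```python
import numpy as np

# Class names in label order 0..9, as listed in the Fashion-MNIST README.
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

def label_name(one_hot):
    """Map a one-hot label row to its class name."""
    return FASHION_MNIST_CLASSES[int(np.argmax(one_hot))]

print(label_name([0, 0, 0, 1, 0, 0, 0, 0, 0, 0]))  # Dress
```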

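The dataset can be downloaded from: https://github.com/zalandoresearch/fashion-mnist

Open the address above, find the dataset files listed on the page, and click download to start downloading.

Place the downloaded files in the './data/fashion-mnist/' folder; there is no need to decompress them. The data is now ready, and we can move on to the experiment.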
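Before training, it can save time to confirm that all four archives are in place and readable. The following is a minimal sketch under my own assumptions (the helper names `missing_files` and `read_image_header` are hypothetical). It relies on the idx3 image format, whose 16-byte header holds four big-endian 32-bit integers: a magic number (2051), the image count, the rows, and the columns. This is also why the loader in the next section skips 16 bytes for image files.

```python
import gzip
import os
import struct

EXPECTED_FILES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]

def missing_files(data_dir):
    """Return the expected archive names that are not present in data_dir."""
    return [f for f in EXPECTED_FILES
            if not os.path.isfile(os.path.join(data_dir, f))]

def read_image_header(path):
    """Return (magic, count, rows, cols) from a gzipped idx3 image file."""
    with gzip.open(path, "rb") as f:
        return struct.unpack(">IIII", f.read(16))

if __name__ == "__main__":
    data_dir = "./data/fashion-mnist"
    print("missing:", missing_files(data_dir))
```

If nothing is missing, `read_image_header` on the training image file should report 60000 images of 28 by 28 pixels.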

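2. Data utility functions (utils.py)

This part prepares some data utility functions, including functions for loading data, saving images, and creating folders. These functions are also used in DCGAN, so they are copied over directly for later use. The file is named utils.py, and its code is given below: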

"""
Most codes from https://github.com/carpedm20/DCGAN-tensorflow
"""
from __future__ import division
import scipy.misc
import numpy as np
import os, gzip

import tensorflow as tf
import tensorflow.contrib.slim as slim


def load_mnist(dataset_name):
    data_dir = 'data/' + dataset_name
    def extract_data(filename, num_data, head_size, data_size):
        with gzip.open(filename) as bytestream:
            bytestream.read(head_size)
            buf = bytestream.read(data_size * num_data)
            # np.float was removed in NumPy 1.24; np.float64 is the same type
            data = np.frombuffer(buf, dtype=np.uint8).astype(np.float64)
        return data

    data = extract_data(data_dir + '/train-images-idx3-ubyte.gz', 60000, 16, 28 * 28)
    trX = data.reshape((60000, 28, 28, 1))

    data = extract_data(data_dir + '/train-labels-idx1-ubyte.gz', 60000, 8, 1)
    trY = data.reshape((60000))

    data = extract_data(data_dir + '/t10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28)
    teX = data.reshape((10000, 28, 28, 1))

    data = extract_data(data_dir + '/t10k-labels-idx1-ubyte.gz', 10000, 8, 1)
    teY = data.reshape((10000))

    trY = np.asarray(trY)
    teY = np.asarray(teY)

    X = np.concatenate((trX, teX), axis=0)
    y = np.concatenate((trY, teY), axis=0).astype(np.int64)

    seed = 547
    np.random.seed(seed)
    np.random.shuffle(X)
    np.random.seed(seed)
    np.random.shuffle(y)

    # one-hot encode the labels
    y_vec = np.zeros((len(y), 10), dtype=np.float64)
    for i, label in enumerate(y):
        y_vec[i, label] = 1.0

    return X / 255., y_vec


def check_folder(log_dir):
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    return log_dir


def show_all_variables():
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)


def save_images(images, size, image_path):
    return imsave(inverse_transform(images), size, image_path)


def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    if (images.shape[3] in (3,4)):
        c = images.shape[3]
        img = np.zeros((h * size[0], w * size[1], c))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            img[j * h:j * h + h, i * w:i * w + w, :] = image
        return img
    elif images.shape[3]==1:
        img = np.zeros((h * size[0], w * size[1]))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0]
        return img
    else:
        raise ValueError('in merge(images,size) images parameter ''must have dimensions: HxW or HxWx3 or HxWx4')


def imsave(images, size, path):
    image = np.squeeze(merge(images, size))
    # NOTE: scipy.misc.imsave was removed in SciPy 1.2; this call needs an older SciPy
    return scipy.misc.imsave(path, image)


def inverse_transform(images):
    return (images+1.)/2.
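
The loader above keeps images and labels aligned by reseeding the random generator before each shuffle, and builds the one-hot labels with a loop. As an alternative sketch (not the original author's code, and it yields a different ordering than the listing above), the same two steps can be done with a single permutation and `np.eye`. The function name `shuffle_and_one_hot` is hypothetical.

```python
import numpy as np

def shuffle_and_one_hot(X, y, n_classes=10, seed=547):
    """Shuffle X and y with one shared permutation, then one-hot encode y."""
    rng = np.random.RandomState(seed)
    perm = rng.permutation(len(X))            # one permutation for both arrays
    X, y = X[perm], y[perm]
    y_vec = np.eye(n_classes, dtype=np.float64)[y]   # row i is one-hot of y[i]
    return X, y_vec

# Toy data where image i has label i, so alignment is easy to verify.
X = np.arange(6).reshape(6, 1)
y = np.arange(6)
Xs, ys = shuffle_and_one_hot(X, y, n_classes=6)
print((Xs[:, 0] == ys.argmax(axis=1)).all())  # True
```

Using one permutation makes the pairing explicit instead of relying on two shuffles consuming the random stream identically.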

3. Layer functions (layers.py)

This part writes the layer functions and saves them to layers.py. The code of layers.py is given below:

"""
Most codes from https://github.com/carpedm20/DCGAN-tensorflow
"""
from utils import *

if "concat_v2" in dir(tf):
    def concat(tensors, axis, *args, **kwargs):
        return tf.concat_v2(tensors, axis, *args, **kwargs)
else:
    def concat(tensors, axis, *args, **kwargs):
        return tf.concat(tensors, axis, *args, **kwargs)

def bn(x, is_training, scope):
    return tf.contrib.layers.batch_norm(x,
                                        decay=0.9,
                                        updates_collections=None,
                                        epsilon=1e-5,
                                        scale=True,
                                        is_training=is_training,
                                        scope=scope)

def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="conv2d"):
    with tf.variable_scope(name):
        w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
              initializer=tf.truncated_normal_initializer(stddev=stddev))
        conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')

        biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
        conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())

        return conv

def deconv2d(input_, output_shape, k_h=5, k_w=5, d_h=2, d_w=2, name="deconv2d", stddev=0.02, with_w=False):
    with tf.variable_scope(name):
        # filter : [height, width, output_channels, in_channels]
        w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
                            initializer=tf.random_normal_initializer(stddev=stddev))

        try:
            deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape, strides=[1, d_h, d_w, 1])

        # Support for versions of TensorFlow before 0.7.0
        except AttributeError:
            deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape, strides=[1, d_h, d_w, 1])

        biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
        deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())

        if with_w:
            return deconv, w, biases
        else:
            return deconv

def lrelu(x, leak=0.2, name="lrelu"):
    return tf.maximum(x, leak*x)

def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
    shape = input_.get_shape().as_list()

    with tf.variable_scope(scope or "Linear"):
        matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,
                 tf.random_normal_initializer(stddev=stddev))
        bias = tf.get_variable("bias", [output_size],
        initializer=tf.constant_initializer(bias_start))
        if with_w:
            return tf.matmul(input_, matrix) + bias, matrix, bias
        else:
            return tf.matmul(input_, matrix) + bias
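
The stride-2, SAME-padded layers above determine the spatial sizes used later in the model (28 to 14 to 7 in the discriminator, and 7 to 14 to 28 in the generator). The arithmetic can be sketched as follows; the helper names are my own, and `deconv_out` simply reflects the output shapes that the generator code requests from `conv2d_transpose`.

```python
import math

def conv_same_out(size, stride):
    """Spatial output size of a SAME-padded convolution: ceil(size / stride)."""
    return math.ceil(size / stride)

def deconv_out(size, stride):
    """Spatial output size requested for the transposed convolution."""
    return size * stride

# Discriminator: two stride-2 convolutions on a 28x28 input.
d1 = conv_same_out(28, 2)
d2 = conv_same_out(d1, 2)
print(d1, d2, d2 * d2 * 128)   # 14 7 6272

# Generator: two stride-2 transposed convolutions starting from 7x7.
g1 = deconv_out(7, 2)
g2 = deconv_out(g1, 2)
print(g1, g2)                  # 14 28
```

This is why the generator's second fully connected layer has 128 * 7 * 7 outputs before being reshaped.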

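4. The infoGAN model (infoGAN.py)

This part implements infoGAN itself, including the parameters, the generator and discriminator, and the loss functions. The code is fairly long, and I will explain it in detail later. The final infoGAN code is: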

from __future__ import division
import time

from layers import *
from utils import *


class infoGAN(object):
    model_name = "infoGAN"     # name for checkpoint

    def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir, SUPERVISED=True):
        self.sess = sess
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = result_dir
        self.log_dir = log_dir
        self.epoch = epoch
        self.batch_size = batch_size

        if dataset_name == 'mnist' or dataset_name == 'fashion-mnist':
            # parameters
            self.input_height = 28
            self.input_width = 28
            self.output_height = 28
            self.output_width = 28

            self.z_dim = z_dim         # dimension of noise-vector
            self.y_dim = 12         # dimension of code-vector (label+two features)
            self.c_dim = 1

            self.SUPERVISED = SUPERVISED # if it is true, label info is directly used for code

            # train
            self.learning_rate = 0.0002
            self.beta1 = 0.5

            # test
            self.sample_num = 64  # number of generated images to be saved

            # code
            self.len_discrete_code = 10  # categorical distribution (i.e. label)
            self.len_continuous_code = 2  # gaussian distribution (e.g. rotation, thickness)

            # load mnist
            self.data_X, self.data_y = load_mnist(self.dataset_name)

            # get number of batches for a single epoch
            self.num_batches = len(self.data_X) // self.batch_size
        else:
            raise NotImplementedError

    def classifier(self, x, is_training=True, reuse=False):
        # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        # Architecture : (64)5c2s-(128)5c2s_BL-FC1024_BL-FC128_BL-FC12S’
        # All layers except the last two layers are shared by discriminator
        # Number of nodes in the last layer is reduced by half. It gives better results.
        with tf.variable_scope("classifier", reuse=reuse):

            net = lrelu(bn(linear(x, 64, scope='c_fc1'), is_training=is_training, scope='c_bn1'))
            out_logit = linear(net, self.y_dim, scope='c_fc2')
            out = tf.nn.softmax(out_logit)

            return out, out_logit

    def discriminator(self, x, is_training=True, reuse=False):
        # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
        with tf.variable_scope("discriminator", reuse=reuse):

            net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='d_conv1'))
            net = lrelu(bn(conv2d(net, 128, 4, 4, 2, 2, name='d_conv2'), is_training=is_training, scope='d_bn2'))
            net = tf.reshape(net, [self.batch_size, -1])
            net = lrelu(bn(linear(net, 1024, scope='d_fc3'), is_training=is_training, scope='d_bn3'))
            out_logit = linear(net, 1, scope='d_fc4')
            out = tf.nn.sigmoid(out_logit)

            return out, out_logit, net

    def generator(self, z, y, is_training=True, reuse=False):
        # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
        with tf.variable_scope("generator", reuse=reuse):

            # merge noise and code
            z = concat([z, y], 1)

            net = tf.nn.relu(bn(linear(z, 1024, scope='g_fc1'), is_training=is_training, scope='g_bn1'))
            net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='g_fc2'), is_training=is_training, scope='g_bn2'))
            net = tf.reshape(net, [self.batch_size, 7, 7, 128])
            net = tf.nn.relu(
                bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='g_dc3'), is_training=is_training,
                   scope='g_bn3'))

            out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='g_dc4'))

            return out

    def build_model(self):
        # some parameters
        image_dims = [self.input_height, self.input_width, self.c_dim]
        bs = self.batch_size

        """ Graph Input """
        # images
        self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')

        # labels
        self.y = tf.placeholder(tf.float32, [bs, self.y_dim], name='y')

        # noises
        self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

        """ Loss Function """
        ## 1. GAN Loss
        # output of D for real images
        D_real, D_real_logits, _ = self.discriminator(self.inputs, is_training=True, reuse=False)

        # output of D for fake images
        G = self.generator(self.z, self.y, is_training=True, reuse=False)
        D_fake, D_fake_logits, input4classifier_fake = self.discriminator(G, is_training=True, reuse=True)

        # get loss for discriminator
        d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones_like(D_real)))
        d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros_like(D_fake)))

        self.d_loss = d_loss_real + d_loss_fake

        # get loss for generator
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones_like(D_fake)))

        ## 2. Information Loss
        code_fake, code_logit_fake = self.classifier(input4classifier_fake, is_training=True, reuse=False)

        # discrete code : categorical
        disc_code_est = code_logit_fake[:, :self.len_discrete_code]
        disc_code_tg = self.y[:, :self.len_discrete_code]
        q_disc_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=disc_code_est, labels=disc_code_tg))

        # continuous code : gaussian
        cont_code_est = code_logit_fake[:, self.len_discrete_code:]
        cont_code_tg = self.y[:, self.len_discrete_code:]
        q_cont_loss = tf.reduce_mean(tf.reduce_sum(tf.square(cont_code_tg - cont_code_est), axis=1))

        # get information loss
        self.q_loss = q_disc_loss + q_cont_loss

        """ Training """
        # divide trainable variables into a group for D and a group for G
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'd_' in var.name]
        g_vars = [var for var in t_vars if 'g_' in var.name]
        q_vars = [var for var in t_vars if ('d_' in var.name) or ('c_' in var.name) or ('g_' in var.name)]

        # optimizers
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \
                .minimize(self.d_loss, var_list=d_vars)
            self.g_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
                .minimize(self.g_loss, var_list=g_vars)
            self.q_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
                .minimize(self.q_loss, var_list=q_vars)

        """" Testing """
        # for test
        self.fake_images = self.generator(self.z, self.y, is_training=False, reuse=True)

        """ Summary """
        d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real)
        d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake)
        d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)

        # use a distinct tag so the q_loss curve is not mixed up with g_loss
        q_loss_sum = tf.summary.scalar("q_loss", self.q_loss)
        q_disc_sum = tf.summary.scalar("q_disc_loss", q_disc_loss)
        q_cont_sum = tf.summary.scalar("q_cont_loss", q_cont_loss)

        # final summary operations
        self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum])
        self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum])
        self.q_sum = tf.summary.merge([q_loss_sum, q_disc_sum, q_cont_sum])

    def train(self):

        # initialize all variables
        tf.global_variables_initializer().run()

        # graph inputs for visualize training results
        self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size , self.z_dim))
        self.test_labels = self.data_y[0:self.batch_size]
        self.test_codes = np.concatenate((self.test_labels, np.zeros([self.batch_size, self.len_continuous_code])),
                                           axis=1)

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = (int)(checkpoint_counter / self.num_batches)
            start_batch_id = checkpoint_counter - start_epoch * self.num_batches
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                batch_images = self.data_X[idx*self.batch_size:(idx+1)*self.batch_size]

                # generate code
                if self.SUPERVISED == True:
                    batch_labels = self.data_y[idx * self.batch_size:(idx + 1) * self.batch_size]
                else:
                    batch_labels = np.random.multinomial(1,
                                                         self.len_discrete_code * [float(1.0 / self.len_discrete_code)],
                                                         size=[self.batch_size])

                batch_codes = np.concatenate(
                    (batch_labels, np.random.uniform(-1, 1, size=(self.batch_size, self.len_continuous_code))),
                    axis=1)

                batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)

                # update D network
                _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss],
                                                       feed_dict={self.inputs: batch_images, self.y: batch_codes,
                                                                  self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # update G and Q network
                _, summary_str_g, g_loss, _, summary_str_q, q_loss = self.sess.run(
                    [self.g_optim, self.g_sum, self.g_loss, self.q_optim, self.q_sum, self.q_loss],
                    feed_dict={self.inputs: batch_images, self.z: batch_z, self.y: batch_codes})
                self.writer.add_summary(summary_str_g, counter)
                self.writer.add_summary(summary_str_q, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))

                # save training results for every 300 steps
                if np.mod(counter, 300) == 0:
                    samples = self.sess.run(self.fake_images,
                                            feed_dict={self.z: self.sample_z, self.y: self.test_codes})
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
                                './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
                                    epoch, idx))

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    def visualize_results(self, epoch):
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        """ random noise, random discrete code, fixed continuous code """
        y = np.random.choice(self.len_discrete_code, self.batch_size)
        y_one_hot = np.zeros((self.batch_size, self.y_dim))
        y_one_hot[np.arange(self.batch_size), y] = 1

        z_sample = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample, self.y: y_one_hot})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')

        """ specified condition, random noise """
        n_styles = 10  # must be less than or equal to self.batch_size

        np.random.seed()
        si = np.random.choice(self.batch_size, n_styles)

        for l in range(self.len_discrete_code):
            y = np.zeros(self.batch_size, dtype=np.int64) + l
            y_one_hot = np.zeros((self.batch_size, self.y_dim))
            y_one_hot[np.arange(self.batch_size), y] = 1

            samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample, self.y: y_one_hot})
            # save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
            #             check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_class_%d.png' % l)

            samples = samples[si, :, :, :]

            if l == 0:
                all_samples = samples
            else:
                all_samples = np.concatenate((all_samples, samples), axis=0)

        """ save merged images to check style-consistency """
        canvas = np.zeros_like(all_samples)
        for s in range(n_styles):
            for c in range(self.len_discrete_code):
                canvas[s * self.len_discrete_code + c, :, :, :] = all_samples[c * n_styles + s, :, :, :]

        save_images(canvas, [n_styles, self.len_discrete_code],
                    check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes_style_by_style.png')

        """ fixed noise """
        assert self.len_continuous_code == 2

        c1 = np.linspace(-1, 1, image_frame_dim)
        c2 = np.linspace(-1, 1, image_frame_dim)
        xv, yv = np.meshgrid(c1, c2)
        xv = xv[:image_frame_dim,:image_frame_dim]
        yv = yv[:image_frame_dim, :image_frame_dim]

        c1 = xv.flatten()
        c2 = yv.flatten()

        z_fixed = np.zeros([self.batch_size, self.z_dim])

        for l in range(self.len_discrete_code):
            y = np.zeros(self.batch_size, dtype=np.int64) + l
            y_one_hot = np.zeros((self.batch_size, self.y_dim))
            y_one_hot[np.arange(self.batch_size), y] = 1

            y_one_hot[np.arange(image_frame_dim*image_frame_dim), self.len_discrete_code] = c1
            y_one_hot[np.arange(image_frame_dim*image_frame_dim), self.len_discrete_code+1] = c2

            samples = self.sess.run(self.fake_images,
                                    feed_dict={ self.z: z_fixed, self.y: y_one_hot})

            save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                        check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_class_c1c2_%d.png' % l)

    @property
    def model_dir(self):
        return "{}_{}_{}_{}".format(
            self.model_name, self.dataset_name,
            self.batch_size, self.z_dim)

    def save(self, checkpoint_dir, step):
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess,os.path.join(checkpoint_dir, self.model_name+'.model'), global_step=step)

    def load(self, checkpoint_dir):
        import re
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            # raw string avoids invalid-escape warnings for \d
            counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0
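
To make the "Information Loss" part of `build_model` easier to follow, here is a NumPy sketch of the same computation: softmax cross-entropy on the first 10 outputs of the classifier (the discrete code) plus a squared error on the last 2 outputs (the continuous code). It mirrors the TensorFlow expressions above but is only an illustration; the function name `info_loss` is hypothetical.

```python
import numpy as np

def info_loss(code_logits, code_target, n_disc=10):
    """Return (q_loss, q_disc_loss, q_cont_loss) for a batch of codes."""
    disc_logits, cont_est = code_logits[:, :n_disc], code_logits[:, n_disc:]
    disc_tg, cont_tg = code_target[:, :n_disc], code_target[:, n_disc:]

    # numerically stable log-softmax over the discrete logits
    shifted = disc_logits - disc_logits.max(axis=1, keepdims=True)
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

    # categorical cross-entropy, averaged over the batch
    q_disc = np.mean(-(disc_tg * log_softmax).sum(axis=1))
    # squared error summed over the continuous dimensions, averaged over the batch
    q_cont = np.mean(np.sum((cont_tg - cont_est) ** 2, axis=1))
    return q_disc + q_cont, q_disc, q_cont

# An untrained classifier (all-zero outputs) on a batch of two codes.
logits = np.zeros((2, 12))
target = np.zeros((2, 12))
target[0, 3] = 1.0            # sample 0: class 3
target[1, 7] = 1.0            # sample 1: class 7
target[:, 10:] = [1.0, -1.0]  # continuous codes
print(info_loss(logits, target))
```

With uniform predictions the discrete term equals ln 10 (about 2.303), and the continuous term is 1 + 1 = 2 here, which gives a feel for the scale of `q_loss` at the start of training.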

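5. Training the model with the main function (main.py)

The last step is to train the model. The main.py file sets the parameters, creates the infoGAN model, and trains it. The main code of main.py is: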

from infoGAN import infoGAN

import tensorflow as tf

"""main"""
def main():
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # declare instance for GAN

        infogan = infoGAN(sess,
                      epoch=20,
                      batch_size=64,
                      z_dim=62,
                      dataset_name='fashion-mnist',
                      checkpoint_dir='checkpoint',
                      result_dir='results',
                      log_dir='logs')

        # build graph
        infogan.build_model()

        # show network architecture
        # show_all_variables()

        # launch the graph in a session
        infogan.train()
        print(" [*] Training finished!")

        # visualize learned generator
        infogan.visualize_results(20-1)
        print(" [*] Testing finished!")

if __name__ == '__main__':
    main()
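
The parameters in main.py are hard-coded. If you would rather change them from the command line, one possible approach is an `argparse` wrapper like the sketch below. This is not part of the original code: the function `parse_args` and the flag names are my own assumptions, with defaults copied from the values used above.

```python
import argparse

def parse_args(argv=None):
    """Parse training options; defaults match the values used in main.py."""
    parser = argparse.ArgumentParser(description="Train infoGAN (hypothetical CLI wrapper)")
    parser.add_argument("--epoch", type=int, default=20)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--z_dim", type=int, default=62)
    parser.add_argument("--dataset_name", type=str, default="fashion-mnist",
                        choices=["mnist", "fashion-mnist"])
    parser.add_argument("--checkpoint_dir", type=str, default="checkpoint")
    parser.add_argument("--result_dir", type=str, default="results")
    parser.add_argument("--log_dir", type=str, default="logs")
    return parser.parse_args(argv)

args = parse_args(["--epoch", "5"])
print(args.epoch, args.batch_size, args.z_dim, args.dataset_name)
```

The parsed values could then be passed to the `infoGAN(...)` constructor in place of the literals.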

6. Running the model

To run the model, simply run main.py, wait for training to finish, and then inspect the results.
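If you are unsure where to look for the outputs, the sketch below reproduces how the code above builds its directory names from `model_dir` and the arguments in main.py. The helper `output_dirs` is hypothetical; only the path patterns come from infoGAN.py.

```python
import os

def output_dirs(model_name="infoGAN", dataset_name="fashion-mnist",
                batch_size=64, z_dim=62,
                checkpoint_dir="checkpoint", result_dir="results", log_dir="logs"):
    """Return the directories the training run writes to."""
    # same pattern as the model_dir property in infoGAN.py
    model_dir = "{}_{}_{}_{}".format(model_name, dataset_name, batch_size, z_dim)
    return {
        "checkpoints": os.path.join(checkpoint_dir, model_dir, model_name),
        "results": result_dir + "/" + model_dir,
        "logs": log_dir + "/" + model_name,
    }

for name, path in output_dirs().items():
    print(name, "->", path)
```

With the settings used here, the sample images end up under results/infoGAN_fashion-mnist_64_62.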

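IV. Experimental Results

The experiment ran for 20 epochs in total, and the training results were fairly good. The results for the clothing images are shown directly below.

Results at epoch = 1:

Results at epoch = 5:

Results at epoch = 10:

Results at epoch = 20:

V. Analysis

1. The results are fairly good. Even at epoch = 1 the generated clothing images are already quite clear, and the garments visibly vary in width.

2. The overall file structure is: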

-- data            (folder containing the original dataset)
    |------ fashion-mnist
                |------ t10k-images-idx3-ubyte.gz
                |------ t10k-labels-idx1-ubyte.gz
                |------ train-images-idx3-ubyte.gz
                |------ train-labels-idx1-ubyte.gz
-- utils.py
    {
    import...

    def load_mnist(dataset_name):...

    def check_folder(log_dir):...

    def show_all_variables():...

    def save_images(images, size, image_path):...

    def merge(images, size):...

    def imsave(images, size, path):...

    def inverse_transform(images):...
    }
-- layers.py
    {
    import...

    def bn(x, is_training, scope):...

    def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="conv2d"):...

    def deconv2d(input_, output_shape, k_h=5, k_w=5, d_h=2, d_w=2, name="deconv2d", stddev=0.02, with_w=False):...

    def lrelu(x, leak=0.2, name="lrelu"):...

    def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):...
    }
-- infoGAN.py
    {
    import...

    class infoGAN(object):
        def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir, SUPERVISED=True):...
        def classifier(self, x, is_training=True, reuse=False):...
        def discriminator(self, x, is_training=True, reuse=False):...
        def generator(self, z, y, is_training=True, reuse=False):...
        def build_model(self):...
        def train(self):...
        def visualize_results(self, epoch):...

        def model_dir(self):...
        def save(self, checkpoint_dir, step):...
        def load(self, checkpoint_dir):...
    }
-- main.py
    {
    import ...
 
    def main():...

    if __name__ == '__main__':...
    }   
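
Returning to the first observation above (garments of varying width): those variations come from sweeping the two continuous codes while the noise and the class label stay fixed, as in the last part of `visualize_results`. The sketch below rebuilds that grid of codes with NumPy for one class. The function name `continuous_code_grid` is my own; the 8 by 8 grid corresponds to the batch size of 64 used in this experiment.

```python
import numpy as np

def continuous_code_grid(label, grid_dim=8, n_classes=10, n_cont=2):
    """Build grid_dim*grid_dim codes for one class, sweeping c1 and c2 over [-1, 1]."""
    c1, c2 = np.meshgrid(np.linspace(-1, 1, grid_dim), np.linspace(-1, 1, grid_dim))
    codes = np.zeros((grid_dim * grid_dim, n_classes + n_cont))
    codes[:, label] = 1.0                    # fixed one-hot class label
    codes[:, n_classes] = c1.flatten()       # c1 varies along each row of the image grid
    codes[:, n_classes + 1] = c2.flatten()   # c2 varies along each column
    return codes

codes = continuous_code_grid(label=3)
print(codes.shape)        # (64, 12)
print(codes[0, 10:])      # [-1. -1.]
print(codes[-1, 10:])     # [1. 1.]
```

Feeding these codes to the trained generator together with a fixed z gives one image grid per class, in which each axis corresponds to one continuous code.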

Reposted from blog.csdn.net/z704630835/article/details/83211086