对抗神经网络(GAN)生成MNIST手写数字图片

前言

看过金庸武侠小说的人应当不会对老顽童周伯通感到陌生,他自创了一门武功叫做左右互博术,左手与右手相互较量,无需陪练,经年累月便能练成绝世武功。而我们今天所要讲的对抗神经网络的思想,和左右互博术有着异曲同工之妙。
我们先来看一张图片:

图片中的这些人物有的可能看着有点奇怪,但他们都是通过训练对抗神经网络然后自动生成出来的图片。简而言之,我们可以通过对抗神经网络来创作一些并不存在的事物的图片。
我们今天就来简单介绍一下对抗神经网络的基础原理,然后通过训练一个对抗神经网络来自动生成MNIST手写数字的图片。

什么是GAN

传统的神经网络一般都是用来做图像分类的,即输入一张图片,用来判断它的类别。假设我们现在有一个神经网络D,它被用来识别一张图片是真是图片还是由机器自动生成的图片。另外我们还有一个神经网络G,它通过训练学习真实图像的分布,然后接受给定输入,用来生成类似真实图像的图片。那么对于这两个网络G和D,我们的目标便是:使得生成器G生成的图片能够尽量接近真实得以骗过判别器D,而D也要尽最大努力去识别图片的真假。这是双方的一个博弈过程,就像前面所提到的左右互搏术一样。随着时间的推移,G和D不断对抗,最终两个网络达到了一个动态均衡:生成器生成的图像接近于真实图像分布,而判别器识别不出真假图像,对于给定图像的预测为真的概率基本接近0.5(相当于随机猜测类别)。
下面一张图片是我在知乎上找到的关于GAN的直观理解:

对于GAN更加直观的理解可以用一个例子来说明:造假币的团伙相当于生成器,他们想通过伪造金钱来骗过银行,使得假币能够正常交易,而银行相当于判别器,需要判断进来的钱是真钱还是假币。因此假币团伙的目的是要造出银行识别不出的假币而骗过银行,银行则是要想办法准确地识别出假币。
现在我们来总结一下以上分析,给定真=1,假=0,那么有:

  • 对于给定的真实图片(real image),判别器要为其打上标签1;
  • 对于给定的生成图片(fake image),判别器要为其打上标签0;
  • 对于生成器传给辨别器的生成图片,生成器希望辨别器打上标签1。

以上就是一个关于对于对抗神经网络的简单介绍,由于我也才刚接触到这个东西,所以这里不去深究理论,直接通过代码来加深一下认识。

代码运行环境

  • TensorFlow 1.0
  • Python 3
  • Github地址 Github
    需要说明的是,这个代码并非我所写,而是我在Github上面找到的,而且需要修改几处小的地方才能完整的跑起来,这个我后面会提到。

  • 训练环境 如果你的电脑上有GPU的话,完全可以用自己的电脑进行训练。如果你电脑上没有GPU,那么可以参考我之前的一篇文章:褥资本主义羊毛--Google免费GPU使用小记 ,使用Google提供的免费GPU进行训练,当然前提是要能翻墙。

代码分析

1.参数定义

class MnistModel:

    def __init__(self):
        # mnist测试集
        self.mnist = input_data.read_data_sets('mnist/', one_hot=True)
        # 图片大小
        self.img_size = self.mnist.train.images[0].shape[0]
        # 每步训练使用图片数量
        self.batch_size = 64
        # 图片分块数量
        self.chunk_size = self.mnist.train.num_examples // self.batch_size
        # 训练循环次数
        self.epoch_size = 300
        # 抽取样本数
        self.sample_size = 25
        # 生成器判别器隐含层数量
        self.units_size = 128
        # 学习率
        self.learning_rate = 0.001
        # 平滑参数
        self.smooth = 0.1

2.真实图片与混淆图片

        import os   # 不加这一句博客的网页不能识别代码类型,不会高亮显示

        # 真实图片与混淆图片
        # 不确定输入图片数量 用None
        real_imgs = tf.placeholder(tf.float32, [None, self.img_size], name='real_imgs')
        fake_imgs = tf.placeholder(tf.float32, [None, self.img_size], name='fake_imgs')

3.生成器

    import os   # 不加这一句博客的网页不能识别代码类型,不会高亮显示

    def generator_graph(fake_imgs, units_size, out_size, alpha=0.01):
        # 生成器与判别器属于两个网络 定义不同scope
        with tf.variable_scope('generator'):
            # 构建一个全连接层
            layer = tf.layers.dense(fake_imgs, units_size)
            # leaky ReLU 激活函数
            relu = tf.maximum(alpha * layer, layer)
            # dropout 防止过拟合
            drop = tf.layers.dropout(relu, rate=0.2)
            # logits
            # out_size应为真实图片size大小
            logits = tf.layers.dense(drop, out_size)
            # 激活函数 将向量值限定在某个区间 与 真实图片向量类似
            # 这里tanh的效果比sigmoid好一些
            # 输出范围(-1, 1) 采用sigmoid则为[0, 1]
            outputs = tf.tanh(logits)
            return logits, outputs

我们使用了一个采用Leaky ReLU作为激活函数的隐层,使用dropout防止过拟合,并在输出层加入tanh激活函数。

4.判别器

    import os   # 不加这一句博客的网页不能识别代码类型,不会高亮显示

    def discriminator_graph(imgs, units_size, alpha=0.01, reuse=False):
        with tf.variable_scope('discriminator', reuse=reuse):
            # 构建全连接层
            layer = tf.layers.dense(imgs, units_size)
            # leaky ReLU 激活函数
            relu = tf.maximum(alpha * layer, layer)
            # logits
            # 判断图片真假 out_size直接限定为1
            logits = tf.layers.dense(relu, 1)
            # 激活函数
            outputs = tf.sigmoid(logits)
            return logits, outputs

判别器接收一张图片,并判断它的真假,同样隐层使用了Leaky ReLU,输出层为1个结点,输出为1的概率。

5.损失

    import os   # 不加这一句博客的网页不能识别代码类型,不会高亮显示

    def loss_graph(real_logits, fake_logits, smooth):
        # 生成器图片loss
        # 生成器希望判别器判断出来的标签为1
        gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.ones_like(fake_logits) * (1 - smooth)))
        # 判别器识别生成器图片loss
        # 判别器希望识别出来的标签为0
        fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.zeros_like(fake_logits)))
        # 判别器识别真实图片loss
        # 判别器希望识别出来的标签为1
        real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits, labels=tf.ones_like(real_logits) * (1 - smooth)))
        # 判别器总loss
        dis_loss = tf.add(fake_loss, real_loss)
        return gen_loss, fake_loss, real_loss, dis_loss

理解这一段代码我们需要参考前面的这部分内容:

  • 对于给定的真实图片(real image),判别器要为其打上标签1;
  • 对于给定的生成图片(fake image),判别器要为其打上标签0;
  • 对于生成器传给辨别器的生成图片,生成器希望辨别器打上标签1。

代码中 real_loss 变量的计算即代表着上面的第一句话,real_logits 代表判别器D对真实图片的分类结果,我们希望让他尽量接近于1。需要指出的是,这里我们并不是让他完全接近1,而是让他尽量接近 (1 - smooth),称smooth为平滑值,这是一种防止过拟合的方式,能让模型具有更好的鲁棒性。
代码中 fake_loss 变量的计算即代表着上面的第二句话,fake_logits 代表判别器D对生成图片的分类结果,我们希望让他尽量接近于0。
代码中的 dis_loss 变量是 real_loss 与 fake_loss 的和,即为判别器D的总loss,即D需要最小化的量。
代码中的 gen_loss 变量的计算即代表着上面的第三句话,生成器G希望判别器D对生成图片的分类结果尽量接近于1,这里同样采用了前面提到的平滑处理。gen_loss即为生成器G的loss,即G需要最小化的量。

6.优化

    import os   # 不加这一句博客的网页不能识别代码类型,不会高亮显示

    def optimizer_graph(gen_loss, dis_loss, learning_rate):
        # 所有定义变量
        train_vars = tf.trainable_variables()
        # 生成器变量
        gen_vars = [var for var in train_vars if var.name.startswith('generator')]
        # 判别器变量
        dis_vars = [var for var in train_vars if var.name.startswith('discriminator')]
        # optimizer
        # 生成器与判别器作为两个网络需要分别优化
        gen_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(gen_loss, var_list=gen_vars)
        dis_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(dis_loss, var_list=dis_vars)
        return gen_optimizer, dis_optimizer

这里定义了两个优化器 gen_optimizer 和 dis_optimizer ,分别为生成器G和判别器D的优化器,作用即最小化各自对应的loss。需要指出的是,G和D两个网络是分开训练的,因此这里先获取到了每个网络各自对应的参数然后再进行优化。

7.开始训练

        import os   # 不加这一句博客的网页不能识别代码类型,不会高亮显示

        # 开始训练
        saver = tf.train.Saver()
        step = 0
        # 指定占用GPU比例
        # tensorflow默认占用全部GPU显存 防止在机器显存被其他程序占用过多时可能在启动时报错
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
            sess.run(tf.global_variables_initializer())
            for epoch in range(self.epoch_size):
                for _ in range(self.chunk_size):
                    batch_imgs, _ = self.mnist.train.next_batch(self.batch_size)
                    batch_imgs = batch_imgs * 2 - 1
                    # generator的输入噪声
                    noise_imgs = np.random.uniform(-1, 1, size=(self.batch_size, self.img_size))
                    # 优化
                    _ = sess.run(gen_optimizer, feed_dict={fake_imgs: noise_imgs})
                    _ = sess.run(dis_optimizer, feed_dict={real_imgs: batch_imgs, fake_imgs: noise_imgs})
                    step += 1

8.结果

9.代码中需要修改的地方

  1. mnist_model.py 文件33行 def generator_graph(fake_imgs, units_size, out_size, alpha=0.01): 应当改为 def generator_graph(fake_imgs, units_size, out_size, alpha=0.01, reuse=False):

  2. mnist_model.py 文件35行 with tf.variable_scope('generator'): 应当改为 with tf.variable_scope('generator', reuse=reuse):

  3. mnist_model.py 文件149行 gen_logits, gen_outputs = self.generator_graph(sample_imgs, self.units_size, self.img_size) 应当改为 gen_logits, gen_outputs = self.generator_graph(sample_imgs, self.units_size, self.img_size, reuse=tf.AUTO_REUSE)

  4. mnist_model.py 文件143行 model_path = os.getcwd() + os.sep + "mnist.model" 应当改为 model_path = "."

其他细节请参考Github上的README文件。

总结

这里我们简单介绍了对抗神经网络的理论基础,并且基于Github上的已有代码简单分析了使用GAN生成MNIST手写数字图片的过程。从最终的模型结果来看,生成的图像能够将背景与数字区分开,黑色块噪声逐渐消失,但从显示结果来看还是有很多模糊区域的。 原因主要是因为我们这里所构建的生成器和判别器的结构都非常简单,只用到了含有一个隐层的全连接神经网络。如果仔细设计一下网络结构,增加层数,使用卷积神经网络代替全连接神经网络,相信会有更好的效果,大家也可以尝试一下。

参考文章