如何编写 GAN 训练算法和损失函数的代码

生成对抗网络(简称GAN)是一种用于训练生成模型的架构。

该架构由两个模型组成。我们感兴趣的生成器模型,以及一个用于协助训练生成器的判别器模型。最初,生成器和判别器模型都是作为多层感知器(MLP)实现的,但最近,这些模型被实现为深度卷积神经网络。

理解GAN如何训练以及如何理解和实现生成器和判别器模型的损失函数可能具有挑战性。

在本教程中,您将学习如何实现生成对抗网络训练算法和损失函数。

完成本教程后,您将了解:

  • 如何实现生成对抗网络的训练算法。
  • 判别器和生成器的损失函数是如何工作的。
  • 如何在实践中实现判别器和生成器模型的权重更新。

我的新书《Python生成对抗网络》将助您启动项目,其中包括分步教程和所有示例的Python源代码文件。

让我们开始吧。

  • 2020年1月更新:修正了训练算法描述中的一个小错字。
How to Code the Generative Adversarial Network Training Algorithm and Loss Functions

如何编写生成对抗网络训练算法和损失函数的代码
图片由Hilary Charlotte拍摄,保留部分权利。

教程概述

本教程分为三个部分;它们是:

  1. 如何实现GAN训练算法
  2. 理解GAN损失函数
  3. 如何在实践中训练GAN模型

注意:本教程中的代码示例仅为片段,并非独立的运行示例。它们旨在帮助您对算法建立直观理解,并可作为您在自己的项目中实现GAN训练算法的起点。

如何实现GAN训练算法

GAN训练算法涉及同时训练判别器和生成器模型。

该算法总结如下图所示,引自Goodfellow等人于2014年发表的原始论文《生成对抗网络》。

Summary of the Generative Adversarial Network Training Algorithm

生成对抗网络训练算法总结。摘自:《生成对抗网络》。

让我们花点时间来理解和熟悉这个算法。

该算法的外部循环涉及迭代训练架构中的模型的步骤。这个循环的一个周期不是一个epoch:它是由对判别器和生成器模型的特定批次更新组成的一次更新。

一个epoch定义为一次遍历训练数据集,其中训练数据集中的样本以小批量方式用于更新模型权重。例如,一个包含100个样本的训练数据集,以10个样本的小批量大小训练模型,将涉及每个epoch进行10次小批量更新。模型将训练给定数量的epoch,例如500个。

这通常通过调用 `fit()` 函数并指定 epoch 数量和每个 mini-batch 的大小来自动训练模型而对您隐藏。

对于GAN,训练迭代的次数必须根据您的训练数据集大小和批次大小来定义。如果数据集包含100个样本,批次大小为10,训练周期为500个,我们将首先计算每个周期内的批次数量,然后使用该数量和周期总数来计算总训练迭代次数。

例如:

如果数据集包含100个样本,批次大小为10,epoch为500个,则GAN将训练 floor(100 / 10) * 500 = 5,000次总迭代。

接下来,我们可以看到一次训练迭代可能会导致判别器进行多次更新,以及生成器进行一次更新,其中对判别器的更新次数是一个超参数,通常设置为1。

训练过程包括同步SGD。在每一步中,采样两个小批量:一个是从数据集中采样的x值小批量,另一个是从模型对潜在变量的先验中抽取的z值小批量。然后同时进行两个梯度步长……

NIPS 2016 教程:生成对抗网络, 2016。

因此,我们可以用Python伪代码总结训练算法如下:

另一种方法可能涉及枚举训练周期数,并将训练数据集分割成每个周期内的批次。

更新判别器模型涉及几个步骤。

首先,必须从潜在空间中选择一批随机点作为生成器模型的输入,以提供生成或“伪造”样本的基础。然后必须从训练数据集中选择一批样本作为“真实”样本输入给判别器。

接下来,判别器模型必须对真实和伪造样本进行预测,并且判别器的权重必须与这些预测的正确或不正确程度成比例地更新。预测是概率,我们将在下一节中深入探讨预测的性质和最小化的损失函数。现在,我们可以概述这些步骤在实践中具体是怎样的。

我们需要一个生成器模型和一个判别器模型,例如Keras模型。这些可以作为参数提供给训练函数。

接下来,我们必须从潜在空间生成点,然后使用当前形式的生成器模型生成一些假图像。例如:

请注意,潜在维度的尺寸也作为超参数提供给训练算法。

然后我们必须选择一批真实样本,这也将被封装到一个函数中。

判别器模型必须随后对每个生成的图像和真实图像进行预测,并且必须更新权重。

接下来,生成器模型必须更新。

同样,必须选择一批来自潜在空间的随机点并传递给生成器以生成假图像,然后传递给判别器进行分类。

然后可以使用该响应来更新生成器模型的权重。

有趣的是,判别器在每次训练迭代中用两批样本进行更新,而生成器在每次训练迭代中只用一批样本进行更新。

现在我们已经定义了GAN的训练算法,我们需要理解模型权重是如何更新的。这需要理解用于训练GAN的损失函数。

想从零开始开发GAN吗?

立即参加我为期7天的免费电子邮件速成课程(附示例代码)。

点击注册,同时获得该课程的免费PDF电子书版本。

理解GAN损失函数

判别器的训练目的是正确分类真实图像和虚假图像。

这是通过最大化真实图像的预测概率对数和虚假图像的倒置概率对数来实现的,并对每个小批量样本进行平均。

回想一下,我们添加对数概率,这与乘以概率相同,只是避免了数值变得过小。因此,我们可以将此损失函数理解为寻求真实图像的概率接近1.0,以及虚假图像的概率接近0.0,然后将其反转为更大的数字。这些值的相加意味着该损失函数的较低平均值会导致判别器更好的性能。

将其转换为最小化问题,如果您熟悉开发用于二元分类的神经网络,那么这就不足为奇了,因为这正是所采用的方法。

这只是在训练带有 Sigmoid 输出的标准二元分类器时最小化的标准交叉熵成本。唯一的区别是分类器在两个小批量数据上进行训练;一个来自数据集,所有示例的标签为 1,另一个来自生成器,所有示例的标签为 0。

NIPS 2016 教程:生成对抗网络, 2016。

生成器模型则更复杂。

GAN算法将生成器模型的损失定义为最小化判别器对伪造图像预测的倒置概率的对数,并对一个小批量进行平均。

这很简单,但根据作者的说法,当生成器性能不佳,而判别器能以高置信度拒绝伪造图像时,这种方法在实践中并不有效。损失函数不再提供生成器可用于调整权重的良好梯度信息,反而会饱和。

在这种情况下,log(1 - D(G(z))) 饱和。与其训练 G 最小化 log(1 - D(G(z))),我们可以训练 G 最大化 log D(G(z))。这个目标函数会导致 G 和 D 动力学的相同不动点,但在学习初期提供更强的梯度。

生成对抗网络,2014年。

相反,作者建议最大化判别器对伪造图像的预测概率的对数。

这种改变是微妙的。

在第一种情况下,生成器的训练目标是最小化判别器判断正确的概率。通过对损失函数的这种改变,生成器的训练目标是最大化判别器判断错误的概率。

在极小极大博弈中,生成器最小化判别器判断正确的对数概率。在此博弈中,生成器最大化判别器判断错误的对数概率。

NIPS 2016 教程:生成对抗网络, 2016。

然后可以反转这个损失函数的符号,得到一个熟悉的用于训练生成器的最小化损失函数。因此,这有时被称为训练GAN的“-log D”技巧。

我们的基线比较是DCGAN,这是一种采用卷积架构的GAN,使用标准GAN过程和-log D技巧进行训练。

Wasserstein GAN,2017年。

现在我们已经了解了GAN的损失函数,接下来我们可以看看如何在实践中更新判别器和生成器模型。

如何在实践中训练GAN模型

GAN损失函数和模型更新的实际实现非常简单。

我们将使用Keras库查看示例。

我们可以通过配置判别器模型来直接实现判别器,使其对真实图像预测为1,对伪造图像预测为0,并最小化交叉熵损失,特别是二元交叉熵损失。

例如,我们的判别器在Keras中的模型定义片段,其输出层和使用适当损失函数进行模型编译可能如下所示。

定义的模型可以针对每个批次的真实和伪造样本进行训练,提供预期结果为1和0的数组。

ones()zeros()NumPy函数可用于创建这些目标标签,而Keras函数train_on_batch()可用于更新每个批次样本的模型。

判别器模型将被训练以预测给定输入图像“真实性”的概率,这可以解释为类别标签,其中类=0表示伪造,类=1表示真实。

生成器被训练以最大化判别器预测生成图像具有高“真实性”的概率。

这通过将生成器通过判别器更新,并为生成的图像指定类别标签1来实现。在此操作中,判别器不会更新,但会提供更新生成器模型权重所需的梯度信息。

例如,如果判别器对一批生成的图像预测的平均概率较低,那么这将导致一个大的误差信号反向传播到生成器中,因为样本的“期望概率”是1.0(真实)。这个大的误差信号反过来会导致生成器发生相对较大的改变,以期在下一批样本中提高其生成伪造样本的能力。

这可以在 Keras 中通过创建一个复合模型来实现,该模型将生成器和判别器模型结合起来,允许生成器的输出图像直接流入判别器,从而允许判别器预测概率的误差信号反向流经生成器模型的权重。

例如:

然后可以使用伪造图像和真实类标签更新复合模型。

这完成了我们对GAN训练算法、损失函数以及判别器和生成器模型权重更新细节的讲解。

进一步阅读

如果您想深入了解,本节提供了更多关于该主题的资源。

论文

文章

总结

在本教程中,您学习了如何实现生成对抗网络训练算法和损失函数。

具体来说,你学到了:

  • 如何实现生成对抗网络的训练算法。
  • 判别器和生成器的损失函数是如何工作的。
  • 如何在实践中实现判别器和生成器模型的权重更新。

你有什么问题吗?
在下面的评论中提出你的问题,我会尽力回答。

立即开发生成对抗网络!

Generative Adversarial Networks with Python

在几分钟内开发您的GAN模型

...只需几行python代码

在我的新电子书中探索如何实现
使用 Python 构建生成对抗网络

它提供了关于以下内容的自学教程端到端项目
DCGAN条件GAN图像翻译Pix2PixCycleGAN
以及更多...

最终将GAN模型引入您的视觉项目

跳过学术理论。只看结果。

查看内容

如何编写GAN训练算法和损失函数代码的22条回复

  1. Kate 2019年7月12日上午8:18 #

    嗨,Jason,

    如果我学习了Python时间序列书籍以及《从零开始的机器学习算法》。你认为我能理解《今天开发生成对抗网络》这本书,还是应该先从深度学习书籍开始?

    谢谢!

    凯特

  2. Olufikayo Bolu 2019年8月10日下午3:52 #

    嗨 Jason,感谢您的贡献和解释。我正在考虑将GAN用于非图像数据。鉴于此示例是使用图像数据实现的,您会如何建议我开始实施?另外,您可以提供您已实施示例的完整示例代码吗?提前感谢

  3. patti 2020年2月5日上午2:53 #

    您好,Jason,非常感谢您的精彩教程。

    我对生成器损失函数有点困惑,那么它实际上在 train_gan 代码部分使用了二元交叉熵,对吗?

  4. The Coder heist 2020年5月19日上午5:19 #

    嗨 Jason,感谢您的贡献和解释。我正在尝试为 CNN LSTM 时空帧预测算法开发 GAN。请告诉我它的可行性和一些参考文献。

  5. kevin 2020年10月4日下午2:14 #

    为什么在实现gan训练时需要将discriminator.trainable = False?

    • Jason Brownlee 2020年10月4日下午2:59 #

      这样在更新生成器时,我们不会更新判别器。

      • Zeinab 2021年5月10日下午10:32 #

        你好
        您有关于使用GAN将文本转换为图像的教程吗?

  6. Cara Evangeline 2021年3月9日上午12:23 #

    嗨 Jason,感谢这个很棒的博客!
    我有一个关于GAN时间序列预测的问题。对于判别器,我们可以将损失函数声明为binary_crossentropy,那么生成器网络会是怎样的情况呢?
    生成器将预测下一个时间步,那么损失函数会是MAE还是binary_entropy?

    • Jason Brownlee 2021年3月9日上午5:21 #

      抱歉,我对时间序列的GAN不了解,只了解图像数据。

  7. MIMI 2021年5月19日上午1:15 #

    你好
    为什么有人在GAN中使用HOG?
    如果有任何方式可以交流,我有很多问题想问您。

  8. Irshad 2021年5月19日下午10:34 #

    谢谢你,Jason。很棒的帖子。

    考虑一个需要生成三维数据的例子。根据这篇文章的理解,判别器模型的损失被用来改变生成器模型的权重。但是生成器如何知道它必须在哪一个维度上改变权重?它会对所有维度进行相同的权重改变吗?

    • Jason Brownlee 2021年5月20日上午5:47 #

      通常GAN用于生成图像数据,抱歉,我对生成其他数据类型不了解。

  9. prachi 2021年11月27日上午1:50 #

    先生,我正在尝试使用 GAN 创建用于风景图像文本检测的合成数据集。
    你能提供将使用的技术吗?

    • Adrian Tam
      Adrian Tam 2021年11月29日上午8:39 #

      你尝试过什么?合成数据的方法有很多种,因此,各种技术都与之相关。如果你首先想到一个合成数据的想法,那么讨论这些技术会更容易。

发表回复

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。