如何使用 Keras 开发最小二乘生成对抗网络 (LSGAN)

最小二乘生成对抗网络,简称 LSGAN,是 GAN 架构的扩展,解决了梯度消失和损失饱和问题。

其动机是为了给生成器提供一个信号,关于那些远离判别器模型决策边界的假样本,这些样本被判别器分类为真或假。生成的图像离决策边界越远,提供给生成器的误差信号就越大,从而鼓励生成更逼真的图像。

LSGAN 可以通过对判别器层的输出层进行微小改动,并采用最小二乘(或 L2)损失函数来实现。

在本教程中,您将了解如何开发最小二乘生成对抗网络。

完成本教程后,您将了解:

  • LSGAN 解决了深度卷积 GAN 的梯度消失和损失饱和问题。
  • LSGAN 可以通过均方误差或 L2 损失函数来实现判别器模型。
  • 如何为 MNIST 数据集实现 LSGAN 模型来生成手写数字。

立即开始您的项目,阅读我的新书 《Python 生成对抗网络》,其中包含分步教程以及所有示例的Python 源代码文件。

让我们开始吧。

  • 2021 年 1 月更新:已更新,以便层冻结与批处理归一化一起使用。
How to Develop a Least Squares Generative Adversarial Network (LSGAN) for Image Generation

如何开发用于图像生成的最小二乘生成对抗网络(LSGAN)
照片来源:alyssa BLACK,部分权利保留。

教程概述

本教程分为三个部分;它们是:

  1. 什么是最小二乘 GAN
  2. 如何为 MNIST 手写数字开发 LSGAN
  3. 如何使用 LSGAN 生成图像

什么是最小二乘 GAN

标准的生成对抗网络,简称 GAN,是训练无监督生成器的一种有效架构。

该架构涉及训练一个判别器模型来区分真实图像(来自数据集)和假图像(生成的图像),然后利用判别器来训练生成器模型。生成器会以一种鼓励其生成更能欺骗判别器的图像的方式进行更新。

判别器是一个二元分类器,使用二元交叉熵损失函数进行训练。此损失函数的一个局限性在于,它主要关注预测是否正确,而不太关注预测的正确或错误程度。

“……当我们使用假样本通过让判别器相信它们是真实数据来更新生成器时,几乎不会产生误差,因为它们位于决策边界的正确一侧,即真实数据一侧。”

最小二乘生成对抗网络,2016。

这可以用二维概念来理解,即一条线或决策边界将代表真实和假图像的点分开。判别器负责设计决策边界以最佳地分离真实和假图像,而生成器负责创建看起来像真实点的点,从而混淆判别器。

选择交叉熵损失意味着远离边界的点是正确或错误的,但为生成器如何生成更好的图像提供的梯度信息非常少。

这些远离决策边界的生成图像的微小梯度被称为梯度消失问题或损失饱和。损失函数无法提供关于如何最佳更新模型的强信号。

最小二乘生成对抗网络,简称 LSGAN,是 Xudong Mao 等人在其 2016 年题为“最小二乘生成对抗网络”的论文中提出的 GAN 架构的扩展。LSGAN 是对 GAN 架构的修改,将判别器的损失函数从二元交叉熵更改为最小二乘损失。

这一改变的动机在于,最小二乘损失将根据生成图像与决策边界的距离来惩罚它们。这将为生成距离现有数据非常不同或非常远的图像提供一个强梯度信号,并解决损失饱和问题。

“……常规 GAN 的目标函数最小化会遇到梯度消失问题,这使得更新生成器变得困难。LSGAN 可以缓解这个问题,因为 LSGAN 根据样本到决策边界的距离进行惩罚,从而生成更多梯度来更新生成器。”

最小二乘生成对抗网络,2016。

这可以用下面论文中的图来理解:左图显示了sigmoid决策边界(蓝色)和远离决策边界的生成假点(粉色),右图显示了最小二乘决策边界(红色)以及远离边界的假点(粉色),这些假点会获得一个使其更接近边界的梯度。

Plot of the Sigmoid Decision Boundary vs the Least Squared Decision Boundary for Updating the Generator

sigmoid 决策边界与最小二乘决策边界更新生成器的对比图。
来源:最小二乘生成对抗网络。

除了避免损失饱和之外,LSGAN 在训练过程也更稳定,并且比传统的深度卷积 GAN 生成更高质量和更大的图像。

首先,LSGAN 能够生成比常规 GAN 更高质量的图像。其次,LSGAN 在学习过程中表现更稳定。

最小二乘生成对抗网络,2016。

LSGAN 可以通过使用真实图像的目标值 1.0 和假图像的目标值 0.0 来实现,并使用均方误差(MSE)损失函数(例如 L2 损失)来优化模型。判别器模型的输出层必须是线性激活函数。

作者提出了一个受 VGG 模型架构启发的生成器和判别器模型架构,并在生成器模型中使用交替的上采样和普通卷积层,如下图左侧所示。

Summary of the Generator (left) and Discriminator (right) Model Architectures used in LSGAN Experiments

LSGAN 实验中使用的生成器(左)和判别器(右)模型架构摘要。
来源:最小二乘生成对抗网络。

想从零开始开发GAN吗?

立即参加我为期7天的免费电子邮件速成课程(附示例代码)。

点击注册,同时获得该课程的免费PDF电子书版本。

如何为 MNIST 手写数字开发 LSGAN

在本节中,我们将为 MNIST 手写数字数据集 开发一个 LSGAN。

第一步是定义模型。

判别器和生成器都将基于深度卷积 GAN,即 DCGAN 架构。这涉及使用卷积-批标准化-激活层块,使用 2x2 步长进行下采样,并使用转置卷积层进行上采样。判别器中使用 LeakyReLU 激活层,生成器中使用 ReLU 激活层

判别器期望灰度输入图像,形状为 28x28,这是 MNIST 数据集中图像的形状,输出层是具有线性激活函数的单个节点。根据 LSGAN,模型使用均方误差(MSE)损失函数进行优化。下面的 `define_discriminator()` 函数定义了判别器模型。

生成器模型以潜在空间中的一个点作为输入,并通过输出层的 tanh 激活函数输出一个形状为 28x28 像素的灰度图像,其中像素值在 [-1,1] 范围内。

下面的 `define_generator()` 函数定义了生成器模型。此模型未编译,因为它不是独立训练的。

生成器模型通过判别器模型进行更新。这是通过创建一个复合模型来实现的,该模型将生成器堆叠在判别器之上,以便误差信号可以流经判别器反向传播到生成器。

在复合模型中使用时,判别器的权重被标记为不可训练。通过复合模型进行的更新包括使用生成器通过提供潜在空间的随机点作为输入来创建新图像。生成的图像传递给判别器,判别器会将其分类为真实或假。权重会像生成的图像是真实的一样进行更新(例如,目标值为 1.0),从而使生成器朝着生成更逼真的图像进行更新。

下面的 `define_gan()` 函数定义并编译了用于通过判别器更新生成器模型的复合模型,同样根据 LSGAN 使用均方误差进行优化。

接下来,我们可以定义一个函数来加载 MNIST 手写数字数据集,并将像素值缩放到 [-1,1] 范围,以匹配生成器模型输出的图像。

仅使用 MNIST 数据集的训练部分,其中包含 60,000 张数字零到九的居中灰度图像。

然后,我们可以定义一个函数来从训练数据集中检索一批随机选择的图像。

真实图像将返回相应的目标值,用于判别器模型,例如 y=1.0,表示它们是真实的。

接下来,我们可以开发相应的生成器函数。

首先,一个用于生成潜在空间随机点的函数,用作通过生成器模型生成图像的输入。

接下来,一个将使用生成器模型生成一批假图像来更新判别器模型的函数,以及一个表示图像为假的(y=0)目标值。

我们需要在训练过程中定期使用生成器来生成我们可以主观检查并作为选择最终生成器模型基础的图像。

下面的 `summarize_performance()` 函数可以在训练期间调用,以生成并保存图像图和保存生成器模型。图像使用反向灰度颜色映射绘制,以使数字为黑色背景上的黑色。

我们还对训练过程中损失的行为感兴趣。

因此,我们可以将损失记录在列表中,跨越每个训练迭代,然后创建并保存模型学习动态的折线图。`plot_history()` 函数实现了创建和保存学习曲线图的功能。

最后,我们可以通过 `train()` 函数定义主训练循环。

该函数将定义的模型和数据集作为参数,并将训练轮数和批次大小参数化为默认函数参数。

每个训练循环涉及首先生成半批次的真实和假样本,并使用它们来创建一批用于判别器的权重更新。接下来,通过复合模型更新生成器,提供真实(y=1)目标作为模型的预期输出。

每个训练迭代都会报告损失,并在每个 epoch 结束时以生成的图像图的形式总结模型性能。学习曲线图在运行结束时创建并保存。

将所有这些内容结合起来,下面列出了在MNIST手写数字数据集上训练LSGAN的完整代码示例。

注意:该示例可以在CPU上运行,但可能需要一些时间,建议在GPU硬件上运行。

注意:您的结果可能因算法或评估程序的随机性、数值精度的差异而有所不同。建议多次运行该示例并比较平均结果。

运行该示例将报告判别器在真实(d1)和虚假(d2)样本上的损失,以及生成器通过判别器在作为真实样本生成的样本上的损失(g)。

这些分数在每个训练运行结束时打印,并且预计在整个训练过程中保持较小的值。长时间出现零值可能表明出现故障模式,应重新启动训练过程。

每个训练周期结束后都会生成生成的图像图。

运行开始时生成的图像很粗糙。

Example of 100 LSGAN Generated Handwritten Digits after 1 Training Epoch

一个训练周期的LSGAN生成手写数字示例(100个)

经过几个训练周期后,生成的图像开始变得清晰逼真。

请记住:更多的训练周期可能与生成器输出更高质量图像的生成器相对应,也可能不对应。查看生成的图,并选择一个具有最佳图像质量的最终模型。

Example of 100 LSGAN Generated Handwritten Digits After 20 Training Epochs

20个训练周期后的LSGAN生成手写数字示例(100个)

在训练运行结束时,将为判别器和生成器创建学习曲线图。

在这种情况下,我们可以看到训练在整个运行过程中保持相对稳定,但观察到一些非常大的峰值,这会影响绘图的比例。

Plot of Learning Curves for the Generator and Discriminator in the LSGAN During Training.

LSGAN训练过程中生成器和判别器的学习曲线图。

如何使用 LSGAN 生成图像

我们可以使用保存的生成器模型来按需创建新图像。

这可以通过首先根据图像质量选择最终模型,然后加载它,并提供来自潜在空间的新点作为输入来实现,以便从领域生成新的可信图像。

在这种情况下,我们将使用在20个周期(即18,740次(60K/64或937个批次/周期 * 20个周期)训练迭代)后保存的模型。

运行该示例将生成一个10x10(即100个)新的、可信的手写数字图。

Plot of 100 LSGAN Generated Plausible Handwritten Digits

LSGAN生成的可信手写数字图(100个)

进一步阅读

如果您想深入了解,本节提供了更多关于该主题的资源。

论文

API

文章

总结

在本教程中,您学习了如何开发一个最小二乘生成对抗网络。

具体来说,你学到了:

  • LSGAN 解决了深度卷积 GAN 的梯度消失和损失饱和问题。
  • LSGAN 可以通过均方误差或 L2 损失函数来实现判别器模型。
  • 如何为 MNIST 数据集实现 LSGAN 模型来生成手写数字。

你有什么问题吗?
在下面的评论中提出你的问题,我会尽力回答。

立即开发生成对抗网络!

Generative Adversarial Networks with Python

在几分钟内开发您的GAN模型

...只需几行python代码

在我的新电子书中探索如何实现
使用 Python 构建生成对抗网络

它提供了关于以下内容的自学教程端到端项目
DCGAN条件GAN图像翻译Pix2PixCycleGAN
以及更多...

最终将GAN模型引入您的视觉项目

跳过学术理论。只看结果。

查看内容

如何开发最小二乘生成对抗网络(LSGAN)在Keras中的20条回复

  1. Sufyan Danish 2019年7月28日下午10:07 #

    首先,请允许我感谢您付出的辛勤努力,我是一名DL和CV领域的新学习者。我从这个网站上学到了很多关于深度学习(DL)和计算机视觉(CV)的知识。请先生,发布一个关于视频超分辨率或自动视频增强的主题,关于我们如何进行自动视频增强。或者,如果您知道任何涵盖这些内容以及使用GEN或其他深度学习方法的代码的网站,请与我们分享。

    • Jason Brownlee 2019年7月29日上午6:15 #

      谢谢!

      抱歉,我目前没有超分辨率的例子。也许将来会有。

  2. Anne Bierhoff 2019年12月15日下午10:16 #

    嗨,Jason,

    目前,我正在尝试设计一个“生成对抗网络”来将系统的仿真模型映射到一个神经网络,以测试是否可以用相似的精度实现更好的执行时间。

    仿真模型接收其环境的实值测量数据,并从中生成实值输出。

    环境的实值测量数据是有限的,因此我考虑使用生成对抗网络来训练预测器。

    目标是让生成器生成逼真的输入。然后将这些输入馈送到仿真模型以生成输出。然后将输入和输出用于训练预测器。

    总体目标是获得一个高质量且泛化能力强的预测器。

    生成器的目标一方面是生成尽可能逼真的输入,但另一方面也是生成预测器尚未充分训练的输入。

    实现这一目标的最佳方法是什么?

  3. abc 2020年2月25日上午7:13 #

    对于tensorflow 2.0+,请将导入中的keras替换为tensorflow.keras。

    • Jason Brownlee 2020年2月25日上午7:54 #

      感谢您的建议。

      该示例使用了独立的Keras。

  4. Joshua 2020年6月19日上午12:33 #

    你好,

    判别器中的最后一层为什么使用线性激活?既然我们计算的是0到1之间的值的MSE,使用Sigmoid而不是更合理吗?

    Joshua

    • Jason Brownlee 2020年6月19日上午6:17 #

      是的,我们设计上使用线性输出来匹配MSE损失。更多细节可以参考链接的论文。

  5. Manohar 2020年9月23日下午3:15 #

    感谢您的精彩解释。
    我在Google Colab中运行了相同的代码,没有任何修改。生成的图像很糟糕,判别器和生成器的损失都降到了0。
    我不知道模型出了什么问题。你能帮忙吗?

    • Jason Brownlee 2020年9月24日上午6:09 #

      也许可以尝试在您的工作站或AWS EC2实例上运行。

  6. Aria 2020年12月14日上午3:10 #

    你好,Jason。
    非常感谢您的帖子,非常感谢。
    我使用了您提供的代码,没有做任何修改。我在Spyder(Python 3.8.3)中运行。
    然而,我的结果与您展示的结果相差甚远。20个周期后,我的结果比您1个周期的结果还要差。您有什么建议来解释我的结果如此之差,而使用的是与您相同的代码吗?
    诚挚地,Aria

  7. Stefano Sartori 2021年3月30日上午1:53 #

    你好 Jason,

    首先,非常感谢您的深度教程,它们非常有用!

    我有一个关于您的教程中DCGAN和LSGAN训练之间差异的问题

    DCGAN:真实和虚假样本堆叠在一起,然后对堆叠后的样本调用train_on_batch()

    LSGAN:真实和虚假样本分别调用train_on_batch()

    这种差异的原因是什么?

    考虑到训练算法和模型损失函数的性质,我不期望这种差异。

    任何建议都将不胜感激。

    非常感谢&此致

    Stefano

    • Jason Brownlee 2021年3月30日上午6:06 #

      有时将样本批处理在一起会很好,有时将它们分开会更好。

      这可能取决于数据集和模型类型。我建议尝试这两种方法,并找出最适合您的方法。

  8. Hala 2021年5月8日下午4:45 #

    这与噪声损失相似吗?

    • Jason Brownlee 2021年5月9日上午5:53 #

      抱歉,我不明白您的问题。也许您可以详细说明或重新表述一下?

  9. Saksham 2022年1月25日下午11:12 #

    嗨,Jason,
    我尝试了上述方法,但在训练结束时结果非常糟糕。判别器损失很快饱和到零,所以这似乎是收敛失败的情况。
    当我将metrics = [“accuracy”]添加到判别器的编译中时,它起作用了。但我不太确定这如何解决了问题。

    • James Carmichael 2022年2月4日上午10:37 #

      您可能正在处理一个回归问题,并实现了零预测误差。

      或者,你可能正在处理分类问题并实现 100% 的准确率。

      这很不寻常,原因有很多,包括:

      你不小心在训练集上评估了模型性能。
      你的保留数据集(训练集或验证集)太小或不具代表性。
      你的代码中引入了一个错误,它正在做一些与你预期不同的事情。
      你的预测问题很容易或微不足道,可能不需要机器学习。
      最常见的原因是你的保留数据集太小或不代表更广泛的问题。

      可以通过以下方法解决:

      使用 k 折交叉验证来估计模型性能,而不是训练/测试拆分。
      收集更多数据。
      使用不同的数据拆分进行训练和测试,例如 50/50。

  10. Hsiao 2022年3月21日下午7:46 #

    嗨,Jason,

    非常感谢您在此网站上展示的许多DL实现代码!

    您是否可以展示如何在Keras中实现基于能量的GAN(EBGAN)和边界平衡GAN(BEGAN)?我对这些GAN的损失函数感到困惑,这些GAN是否应该使用Lambda层来定义自定义损失函数?

留下回复

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。