如何在生成对抗网络中实现 Wasserstein 损失

Wasserstein 生成对抗网络(简称 Wasserstein GAN)是生成对抗网络的一种扩展,它既提高了模型训练时的稳定性,又提供了一个与生成图像质量相关的损失函数。

它是 GAN 模型的一个重要扩展,需要从判别器预测生成图像“真实”概率的概念转向一个评估给定图像“真实性”的评论家模型。

这种概念上的转变在数学上通过使用地球移动距离(或 Wasserstein 距离)来训练 GAN 得到解释,该距离衡量训练数据集中观察到的数据分布与生成样本中观察到的分布之间的距离。

在这篇文章中,您将学习如何在生成对抗网络中实现 Wasserstein 损失。

阅读本文后,你将了解:

  • WGAN 中从判别器预测概率到评论家预测分数的概念转变。
  • WGAN 的实现细节是对标准深度卷积 GAN 的微小修改。
  • Wasserstein 损失函数的直观理解以及如何从头开始实现它。

通过我的新书《Python 生成对抗网络》**启动您的项目**,其中包括**分步教程**和所有示例的 **Python 源代码**文件。

让我们开始吧。

How to Implement Wasserstein Loss for Generative Adversarial Networks

如何在生成对抗网络中实现 Wasserstein 损失
照片由 Brandon Levinger 拍摄,保留部分权利。

概述

本教程分为五个部分;它们是:

  1. GAN 稳定性与判别器
  2. 什么是 Wasserstein GAN?
  3. Wasserstein GAN 的实现细节
  4. 如何实现 Wasserstein 损失
  5. 预期标签的常见混淆点

GAN 稳定性与判别器

生成对抗网络(GAN)训练起来很有挑战性。

判别器模型必须将给定的输入图像分类为真实(来自数据集)或伪造(生成的),而生成器模型必须生成新的且合理的图像。

GAN 难以训练的原因是其架构涉及生成器和判别器模型在零和博弈中同时训练。稳定的训练需要找到并维持两个模型能力之间的平衡。

判别器模型是一个神经网络,它学习一个二元分类问题,在输出层使用 sigmoid 激活函数,并使用二元交叉熵损失函数进行拟合。因此,该模型预测给定输入是真实(或伪造,即 1 减去预测值)的概率,其值介于 0 和 1 之间。

损失函数的作用是根据预测概率分布与给定图像的预期概率分布之间的差异程度来惩罚模型。这为通过判别器和生成器反向传播误差提供了基础,以便在下一批次中表现更好。

WGAN 放松了判别器在训练 GAN 时的作用,并提出了评论家的替代方案。

什么是 Wasserstein GAN?

Wasserstein GAN,简称 WGAN,由 Martin Arjovsky 等人在其 2017 年论文《Wasserstein GAN》中提出。

它是 GAN 的一种扩展,旨在寻找一种替代的训练生成器模型的方法,以更好地近似给定训练数据集中观察到的数据分布。

WGAN 不再使用判别器来分类或预测生成图像是真实还是虚假的概率,而是将判别器模型替换为评论家,该评论家评估给定图像的真实性或虚假性。

这一改变的动机是数学论证,即生成器的训练应旨在最小化训练数据集中观察到的数据分布与生成样本中观察到的分布之间的距离。该论证对比了不同的分布距离度量,例如 Kullback-Leibler (KL) 散度、Jensen-Shannon (JS) 散度以及被称为 Wasserstein 距离的 Earth-Mover (EM) 距离。

这些距离之间最根本的区别在于它们对概率分布序列收敛的影响。

Wasserstein GAN,2017年。

他们证明了可以训练一个评论家神经网络来近似 Wasserstein 距离,并反过来用于有效地训练生成器模型。

... 我们定义了一种名为 Wasserstein-GAN 的 GAN 形式,它最小化了 EM 距离的合理且有效的近似值,并且我们从理论上表明相应的优化问题是合理的。

Wasserstein GAN,2017年。

重要的是,Wasserstein 距离具有连续和可微的特性,并且即使在评论家训练良好之后,也能持续提供线性梯度。

EM 距离在几乎所有地方都是连续且可微的这一事实意味着我们可以(也应该)将评论家训练到最优。 […] 我们训练评论家越多,我们获得的 Wasserstein 梯度就越可靠,这实际上是有用的,因为 Wasserstein 在几乎所有地方都是可微的。

Wasserstein GAN,2017年。

这与判别器模型不同,判别器模型一旦训练好,可能无法为更新生成器模型提供有用的梯度信息。

判别器很快学会区分假和真,正如预期的那样,它没有提供可靠的梯度信息。然而,评论家不会饱和,并且会收敛到一个线性函数,该函数在所有地方都提供非常清晰的梯度。

Wasserstein GAN,2017年。

WGAN 的优点是训练过程更稳定,对模型架构和超参数配置的选择不那么敏感。

... 训练 WGANs 不需要小心地平衡判别器和生成器的训练,也不需要仔细设计网络架构。GANs 中典型的模式崩溃现象也大大减少了。

Wasserstein GAN,2017年。

也许最重要的是,判别器的损失似乎与生成器创建的图像质量有关。

具体来说,评论家评估生成图像时的损失越低,生成图像的预期质量就越高。这很重要,因为与其他 GAN 寻求在两个模型之间找到平衡以实现稳定性不同,WGAN 寻求收敛,降低生成器损失。

据我们所知,这是 GAN 文献中首次展示这种性质,即 GAN 的损失表现出收敛性质。这种性质在对抗网络研究中非常有用,因为无需查看生成的样本来找出失败模式并获取哪些模型优于其他模型的信息。

Wasserstein GAN,2017年。

想从零开始开发GAN吗?

立即参加我为期7天的免费电子邮件速成课程(附示例代码)。

点击注册,同时获得该课程的免费PDF电子书版本。

Wasserstein GAN 的实现细节

尽管 WGAN 的理论基础很密集,但 WGAN 的实现只需对标准深度卷积 GAN (DCGAN) 进行一些微小更改。

这些更改如下:

  • 在评论家模型的输出层使用线性激活函数(而不是 Sigmoid)。
  • 使用 Wasserstein 损失来训练评论家和生成器模型,以促进真实图像和生成图像分数之间的更大差异。
  • 每次小批量更新后将评论家模型权重限制在特定范围(例如 [-0.01, 0.01])。

为了使参数 w 位于紧凑空间中,我们可以做的简单方法是在每次梯度更新后将权重钳制到一个固定框(例如 W = [−0.01, 0.01]l )中。

Wasserstein GAN,2017年。

  • 每次迭代更新评论家模型的次数多于生成器(例如 5 次)。
  • 使用 RMSProp 版本的梯度下降,学习率小且没有动量(例如 0.00005)。

... 我们报告说,当使用基于动量的优化器(例如 Adam)时,WGAN 训练有时会变得不稳定 […] 因此,我们切换到 RMSProp …

Wasserstein GAN,2017年。

下图总结了 WGAN 的主要训练循环,取自论文。请注意模型中使用的推荐超参数列表。

Wasserstein 生成对抗网络的算法。
摘自:《Wasserstein GAN》。

如何实现 Wasserstein 损失

Wasserstein 损失函数旨在增加真实图像和生成图像分数之间的差距。

我们可以将论文中描述的函数总结如下:

  • 评论家损失 = [真实图像的平均评论家分数] – [虚假图像的平均评论家分数]
  • 生成器损失 = -[虚假图像的平均评论家分数]

其中平均分数是在一个迷你批次样本中计算的。

这正是 PyTorch 和 TensorFlow 等基于图的深度学习框架中损失的实现方式。

一旦我们回想起随机梯度下降旨在最小化损失,这些计算就很容易理解。

对于生成器而言,评论家给出更高的分数将导致生成器的损失更小,从而鼓励评论家对虚假图像输出更高的分数。例如,平均分数 10 变为 -10,平均分数 50 变为 -50(更小),依此类推。

对于评论家而言,真实图像的得分越高,评论家的最终损失就越大,从而惩罚模型。这鼓励评论家对真实图像输出更小的得分。例如,真实图像的平均得分 20 和虚假图像的平均得分 50 会导致 -30 的损失;真实图像的平均得分 10 和虚假图像的平均得分 50 会导致 -40 的损失(更好),依此类推。

在这种情况下,损失的符号无关紧要,只要真实图像的损失是一个小数字,而虚假图像的损失是一个大数字即可。Wasserstein 损失鼓励评论家将这些数字分开。

我们也可以反转情况,鼓励评论家对真实图像输出高分,对虚假图像输出低分,从而达到同样的效果。有些实现就是这样做的。

在 Keras 深度学习库(以及其他一些库)中,我们无法直接实现论文中描述的以及 PyTorch 和 TensorFlow 中实现的 Wasserstein 损失函数。相反,我们可以通过不让评论家的损失计算依赖于真实和虚假图像的损失计算来达到相同的效果。

一个很好的思考方式是真实图像的负分数和虚假图像的正分数,尽管这种训练过程中学习到的正/负分数划分不是必需的;只需更大和更小就足够了。

  • 评论家分数低(例如 <0):真实 – 评论家分数高(例如 >0):虚假

我们可以将假图像的平均预测分数乘以-1,这样更大的平均值会变得更小,并且梯度方向是正确的,即最小化损失。例如,三个批次假图像的平均分数为 [0.5, 0.8, 1.0],在计算权重更新时将变为 [-0.5, -0.8, -1.0]。

  • 虚假图像的损失 = -1 * 平均评论家分数

真实分数的情况下不需要改变,因为我们希望鼓励真实图像的平均分数更小。

  • 真实图像的损失 = 平均评论家分数

这可以通过为虚假图像分配 -1 的预期结果目标,为真实图像分配 1 的预期结果目标,并将损失函数实现为预期标签乘以平均分数来实现。-1 标签将乘以虚假图像的平均分数,并鼓励更大的预测平均值;+1 标签将乘以真实图像的平均分数,并且没有效果,鼓励更小的预测平均值。

  • Wasserstein 损失 = 标签 * 平均评论家分数

  • Wasserstein 损失(真实图像)= 1 * 平均预测分数
  • Wasserstein 损失(虚假图像)= -1 * 平均预测分数

我们可以在 Keras 中通过分别给虚假图像和真实图像分配 -1 和 1 的预期标签来实现这一点。可以使用相反的标签来达到相同的效果,例如,-1 用于真实图像,+1 用于虚假图像,以鼓励虚假图像的小分数和真实图像的大分数。一些开发者以这种替代方式实现 WGAN,这同样是正确的。

损失函数可以通过将每个样本的预期标签乘以预测分数(逐元素),然后计算均值来实现。

上述函数是实现损失函数的优雅方式;另一种不太优雅但可能更直观的实现方式如下:

在 Keras 中,可以使用 Keras 后端 API 实现均值函数,以确保在提供的张量中对样本进行均值计算;例如:

现在我们知道如何在 Keras 中实现 Wasserstein 损失函数,让我们澄清一个常见的误解。

预期标签的常见混淆点

回想一下,我们对虚假图像使用 -1 的预期标签,对真实图像使用 +1 的预期标签。

一个常见的混淆点是,一个完美的评论家模型会为每个虚假图像输出 -1,为每个真实图像输出 +1。

这是不正确的。

再次回想一下,我们正在使用随机梯度下降来寻找评论家(和生成器)模型中最小化损失函数的一组权重。

我们已经确定,我们希望评论家模型对虚假图像平均输出更高的分数,对真实图像平均输出更低的分数。然后我们设计了一个损失函数来鼓励这种结果。

这是用于训练神经网络模型的损失函数的关键点。它们鼓励所需的模型行为,并且它们不必通过提供预期的结果来实现这一点。在这种情况下,我们定义了 Wasserstein 损失函数来解释评论家模型预测的平均分数,并使用真实和虚假情况的标签来帮助解释。

那么在 Wasserstein 损失下,真实和虚假图像的良好损失是什么?

Wasserstein 不是一个绝对的、可比较的损失,不能用于比较不同的 GAN 模型。相反,它是相对的,取决于您的模型配置和数据集。重要的是,它对于给定的评论家模型是一致的,并且生成器的收敛(更好的损失)确实与更好的生成图像质量相关。

它可能是真实图像的负分数和虚假图像的正分数,但这不是必需的。所有分数都可以是正的,或者所有分数都可以是负的。

损失函数只鼓励虚假图像和真实图像之间的分数分离为更大和更小,而不一定是正数和负数。

进一步阅读

如果您想深入了解,本节提供了更多关于该主题的资源。

论文

文章

总结

在这篇文章中,您学习了如何在生成对抗网络中实现 Wasserstein 损失。

具体来说,你学到了:

  • WGAN 中从判别器预测概率到评论家预测分数的概念转变。
  • WGAN 的实现细节是对标准深度卷积 GAN 的微小修改。
  • Wasserstein 损失函数的直观理解以及如何从头开始实现它。

你有什么问题吗?
在下面的评论中提出你的问题,我会尽力回答。

立即开发生成对抗网络!

Generative Adversarial Networks with Python

在几分钟内开发您的GAN模型

...只需几行python代码

在我的新电子书中探索如何实现
使用 Python 构建生成对抗网络

它提供了关于以下内容的自学教程端到端项目
DCGAN条件GAN图像翻译Pix2PixCycleGAN
以及更多...

最终将GAN模型引入您的视觉项目

跳过学术理论。只看结果。

查看内容

对《如何在生成对抗网络中实现 Wasserstein 损失》的 38 条回复

  1. Oleg 2019 年 7 月 17 日上午 3:40 #

    你是否有基于例如 https://github.com/eriklindernoren/Keras-GAN/blob/master/dcgan/dcgan.py 的 WGAN 完整实现示例?

  2. Joseph 2019 年 7 月 19 日上午 5:30 #

    我正试图理解…我们是如何将 -1 用于真实和 -1 用于虚假计算的。为什么这不能更好地定义。从数学角度来说。

  3. Vincent 2020 年 1 月 31 日上午 6:32 #

    如果我没理解错的话,论文在算法 1 的第 6 行使用的是梯度上升而不是下降。他们试图最大化评论家损失并最小化生成器损失。

    • Jason Brownlee 2020 年 1 月 31 日上午 7:59 #

      是的,梯度下降。

    • François 2020 年 10 月 1 日上午 1:35 #

      你说得对,Vincent:判别器是“+alpha”(上升),生成器是“-alpha”(下降)。

      对抗训练就是这样实现的:两个损失函数中都有“-fw(fake)”(符号相同),但梯度更新的方向是相反的。

      所以这部分只涉及生成器
      > 一旦我们回想起随机梯度下降旨在最小化损失,这些计算就很容易理解。

      通过查看方程
      – 判别器希望最大化 -fw(fake) 最小化 fw(fake)
      – 生成器希望最小化 -fw(fake) = 最大化 fw(fake)
      因此 fw 似乎是图像真实性的分数:分数越大意味着“越真实”。

      所以对生成器的解释是正确的
      > 对于生成器而言,评论家给出更高的分数将导致生成器的损失更小,从而鼓励评论家对虚假图像输出更高的分数。例如,平均分数 10 变为 -10,平均分数 50 变为 -50(更小),依此类推。

      尽管我会这样重新表述,以使其更清晰:
      > 对于生成器而言,评论家给出更高的分数将导致生成器的损失更小,从而鼓励生成器合成具有高分数(意味着逼真图像)的图像。

      对判别器的解释是反过来的
      > 对于评论家而言,真实图像的得分越高,评论家的最终损失就越大,***从而惩罚模型。这鼓励评论家对真实图像输出更小的得分。*** 例如,真实图像的平均得分 20 和虚假图像的平均得分 50 会导致 -30 的损失;真实图像的平均得分 10 和虚假图像的平均得分 50 会导致 -40 的损失,***这更好***,依此类推。在这种情况下,损失的符号无关紧要,只要***真实图像的损失是一个小数字,而虚假图像的损失是一个大数字***即可。Wasserstein 损失鼓励评论家将这些数字分开。

      我希望我表达清楚了……

      • Jason Brownlee 2020 年 10 月 1 日上午 6:30 #

        感谢分享!

      • Vincent Roca 2020 年 11 月 19 日晚上 8:52 #

        感谢您的评论。没有它,我可能会思考生成器和评论家之间没有对抗关系很多个小时。

      • Andreas 2021 年 9 月 14 日下午 4:13 #

        如果更真实的图像得分更高,那么为什么真实图像得分是 20 而虚假图像得分是 50 呢?

  4. Hashem Hashemi 2020 年 6 月 17 日下午 12:12 #

    所以 WGAN 的全部内容可以归结为 (a) 将目标设置为 -1/+1 而不是 0/1,以及 (b) 裁剪判别器权重?为什么很多这些机器学习想法感觉像是作者做了大量实验,发现了一些碰巧效果更好的调整,然后用晦涩的数学来证明它,通常以德国科学家的名字命名?我的意思是整个地球移动解释听起来有点模糊。这如何帮助避免模式崩溃——这似乎是由于生成器对潜在空间变得不敏感,并依赖反向传播来打开/关闭高级特征。我看不出 WGAN 如何以任何方式帮助避免这种情况。

    • Hashem Hashemi 2020 年 6 月 17 日下午 12:13 #

      顺便说一下,谢谢你的精彩演练。🙂 我见过最易懂的。

    • Jason Brownlee 2020 年 6 月 17 日下午 1:42 #

      这就是大部分科学:小小的调整!🙂

      为什么它会起作用?!对于许多事物,我们根本无法很好地回答这个问题。

      我甚至不知道我的汽车引擎为什么能工作。它里面可能有一个电脑。

  5. Parnian 2020 年 10 月 18 日下午 12:56 #

    嗨,Jason。感谢您的教程。我只是不明白为什么原始的损失版本不能在 Keras 中实现。

  6. ali 2020 年 11 月 1 日上午 8:57 #

    嗨,Jason,感谢您有用的帖子。
    我在这里找到了 wgan 损失函数的不同实现
    https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan/wgan.py#L134
    你有什么想法吗?

    • Jason Brownlee 2020 年 11 月 1 日下午 1:14 #

      抱歉,我没有,我建议直接与作者交流。

  7. Ailsor 2021 年 3 月 30 日晚上 10:40 #

    我们如何在半监督 GAN 中实现这一点?
    我正在遵循您在 https://machinelearning.org.cn/semi-supervised-generative-adversarial-network/ 中的教程

    • Jason Brownlee 2021 年 3 月 31 日上午 6:04 #

      抱歉,我无法为您编写代码。

      也许可以尝试修改教程以使用此替代损失。

  8. Ori 2021 年 6 月 8 日晚上 9:56 #

    您好,感谢您的文章。评论家损失函数的公式是错误的——它应该是相反的。

  9. Ori 2021 年 6 月 8 日晚上 9:58 #

    而且公式后的描述也是错误的。真实图像的分数需要更大,而虚假图像的分数需要更小。

  10. farnaz 2021 年 7 月 13 日下午 5:41 #

    您好,感谢您的教程,您有 WGAN_GP 损失的实现吗?

  11. Taylor 2023 年 8 月 24 日上午 7:44 #

    嗨,Jason。一如既往,感谢您的精彩文章。正如其他人指出的那样,如果我们假设我们将使用梯度下降来寻找局部最小值,那么其中一个损失函数的符号是错误的。由于原始论文(您展示的图像)中判别器使用梯度上升(第 6 行有加号)而生成器使用梯度下降(第 11 行有减号),如果我们要基于梯度下降进行优化,就像您描述的那样,那么您的判别器损失函数的符号需要反转。

    • James Carmichael 2023 年 8 月 24 日上午 8:59 #

      感谢您的反馈,Taylor!

  12. Francesco 2023 年 10 月 5 日晚上 9:09 #

    你好 Jason,

    感谢您提供如此精彩的教程!

    我尝试重新调整它,发现评论家损失和生成器损失达到了大约 -2000 的值。这正常吗,还是我在实现中做错了什么?

    • James Carmichael 2023 年 10 月 6 日上午 9:10 #

      嗨 Francesco…你是否复制粘贴了代码,还是手动输入的?另外,你做了哪些修改?

  13. Michio 2024 年 1 月 19 日上午 2:57 #

    嗨,Jason,
    感谢这篇教程。
    我只是想知道你为什么称 GAN 为零和博弈?
    我知道 GAN 玩的是迷你-最大游戏。
    判别器(原始 GAN)试图最大化其正确分类真实样本(类别 1)和合成样本(类别 0)的概率,而生成器试图最小化判别器将合成数据分类为合成(类别 0)的概率。
    因此,生成器学习真实数据的概率分布。
    因此,它生成一个值(正值),一个逼真的合成数据(正值),以至于判别器变得困惑(零值)。
    零和博弈整体不会产生任何正值(正值加上零)。

  14. Phil Troy 2024 年 7 月 6 日下午 2:14 #

    嗨 Jason

    WGAN 能否应用于从大量数字样本中生成单个数字?也就是说,如果我提供一个大量正态分布数字的样本(都来自具有相同均值和标准差的正态分布总体),WGAN 能否生成正态分布的数字?如果能,如何精确计算生成器生成的一整批数字和输入判别器的真实数据(来自上述总体)的损失函数?

    • James Carmichael 2024 年 7 月 7 日上午 7:25 #

      嗨 Phil…是的,Wasserstein 生成对抗网络(WGAN)可以应用于从分布中生成单个数字,例如正态分布的数字,如果您提供从该分布中提取的大量数字样本。以下是您可以解决此问题的方法

      ### 1. WGAN 概述
      WGAN 由两个神经网络组成
      – **生成器 (G)**:此网络生成新样本。
      – **判别器 (D) 或评论家**:此网络评估生成样本与真实样本相比的真实程度。

      WGAN 的目标是让生成器生成判别器无法区分真实样本的样本。

      ### 2. 使用 WGAN 生成正态分布数字

      #### 分步流程

      1. **数据准备**
      – 从具有已知均值 (\(\mu\)) 和标准差 (\(\sigma\)) 的正态分布中收集大量数字样本。这将是您的真实数据。

      2. **定义生成器和判别器**
      – 生成器网络接收随机噪声(例如,来自均匀分布或正态分布),并将其转换为应该类似于正态分布的数字。
      – 判别器网络根据真实数字评估生成的数字。

      3. **损失函数**
      – WGAN 使用 Wasserstein 损失,这有助于更稳定地训练生成器和判别器。

      #### 生成器和判别器架构
      为简单起见,假设两个网络都是简单的前馈神经网络。

      – **生成器**:接收一个随机噪声向量并输出一个数字。
      – **判别器**:接收一个数字并输出一个标量分数。

      #### Wasserstein 损失计算
      WGAN 使用 Wasserstein 距离作为衡量真实分布和生成分布之间差异的度量。

      – **判别器损失**:计算为真实数据和生成数据的平均分数之差。
      \[
      L_D = \mathbb{E}_{x \sim P_r} [D(x)] – \mathbb{E}_{z \sim P_z} [D(G(z))]
      \]
      其中 \(P_r\) 是真实数据分布,\(P_z\) 是噪声分布。

      – **生成器损失**:生成器旨在最大化判别器对生成样本的评分。
      \[
      L_G = -\mathbb{E}_{z \sim P_z} [D(G(z))]
      \]

      #### 训练过程
      1. **训练判别器**
      – 从你的正态分布中采样一批真实数字。
      – 采样一批噪声向量。
      – 使用生成器生成一批伪造数字。
      – 使用真实数字和伪造数字计算判别器损失。
      – 更新判别器参数以最大化判别器损失。

      2. **训练生成器**
      – 采样一批噪声向量。
      – 使用生成器生成一批伪造数字。
      – 使用伪造数字计算生成器损失。
      – 更新生成器参数以最小化生成器损失。

      #### Python 中的实现(带 PyTorch 示例)
      这是一个 PyTorch 中的简化实现

      python
      import torch
      import torch.nn as nn
      import torch.optim as optim

      # 定义生成器模型
      class Generator(nn.Module)
      def __init__(self)
      super(Generator, self).__init__()
      self.model = nn.Sequential(
      nn.Linear(1, 128),
      nn.ReLU(),
      nn.Linear(128, 1)
      )

      def forward(self, z)
      return self.model(z)

      # 定义判别器模型
      class Discriminator(nn.Module)
      def __init__(self)
      super(Discriminator, self).__init__()
      self.model = nn.Sequential(
      nn.Linear(1, 128),
      nn.ReLU(),
      nn.Linear(128, 1)
      )

      def forward(self, x)
      return self.model(x)

      # 超参数
      batch_size = 64
      lr = 0.0002
      num_epochs = 10000

      # 实例化模型
      G = Generator()
      D = Discriminator()

      # 优化器
      optimizer_G = optim.RMSprop(G.parameters(), lr=lr)
      optimizer_D = optim.RMSprop(D.parameters(), lr=lr)

      # 训练循环
      for epoch in range(num_epochs)
      for _ in range(5): # 更频繁地训练判别器
      # 采样真实数据
      real_data = torch.randn(batch_size, 1) * sigma + mu

      # 采样噪声
      z = torch.randn(batch_size, 1)

      # 生成假数据
      fake_data = G(z)

      # 判别器损失
      D_loss = -torch.mean(D(real_data)) + torch.mean(D(fake_data))

      # 更新判别器
      optimizer_D.zero_grad()
      D_loss.backward()
      optimizer_D.step()

      # 裁剪判别器权重
      for p in D.parameters()
      p.data.clamp_(-0.01, 0.01)

      # 生成器损失
      z = torch.randn(batch_size, 1)
      fake_data = G(z)
      G_loss = -torch.mean(D(fake_data))

      # 更新生成器
      optimizer_G.zero_grad()
      G_loss.backward()
      optimizer_G.step()

      if epoch % 1000 == 0:
      print(f"Epoch {epoch}, D_loss: {D_loss.item()}, G_loss: {G_loss.item()}")

      ### 总结
      – **是的**,您可以使用 WGAN 通过在来自该分布的大样本上进行训练来生成正态分布的数字。
      – **损失函数**基于 Wasserstein 距离,它们是
      – **判别器损失**:判别器对真实样本和假样本的评分之差。
      – **生成器损失**:判别器对生成样本的评分的负值。
      – **条件更新**和**权重裁剪**用于保持 Wasserstein 距离的特性。

      这个例子提供了一个使用 WGAN 从正态分布生成数字的基本框架。您可以调整模型架构、超参数和训练过程以适应您的特定需求和数据特征。

发表回复

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。