如何从头开始使用 Keras 实现半监督 GAN (SGAN)

半监督学习是一个具有挑战性的问题,它涉及在只有少量标记样本和大量未标记样本的数据集上训练分类器。

生成对抗网络(GAN)是一种通过图像判别器模型有效地利用大型未标记数据集来训练图像生成器模型的架构。在某些情况下,判别器模型可以用作开发分类器模型的起点。

半监督生成对抗网络(SGAN)模型是GAN架构的扩展,它涉及同时训练监督判别器、无监督判别器和生成器模型。其结果是,既能得到一个泛化能力强的监督分类模型,也能得到一个能生成该领域图像的合理示例的生成器模型。

在本教程中,您将学习如何从头开始开发一个半监督生成对抗网络。

完成本教程后,您将了解:

  • 半监督GAN是GAN架构的扩展,用于在利用标记和未标记数据训练分类器模型。
  • 在 Keras 中实现半监督 GAN 所使用的监督和无监督判别器模型,至少有三种方法。
  • 如何从头开始在 MNIST 上训练半监督 GAN,并加载和使用训练好的分类器进行预测。

开始您的项目,阅读我的新书 《Python生成对抗网络》,其中包含分步教程以及所有示例的Python源代码文件。

让我们开始吧。

How to Implement a Semi-Supervised Generative Adversarial Network From Scratch

如何从头开始实现半监督生成对抗网络。
照片由 Carlos Johnson 拍摄,保留部分权利。

教程概述

本教程分为四个部分;它们是

  1. 什么是半监督 GAN?
  2. 如何实现半监督判别器模型
  3. 如何为 MNIST 开发半监督 GAN
  4. 如何加载和使用最终的 SGAN 分类器模型

什么是半监督 GAN?

半监督学习是指在需要预测模型但只有少量标记示例和大量未标记示例的问题。

最常见的例子是分类预测建模问题,其中可能有一个非常大的数据集,但只有一小部分具有目标标签。模型必须从少量标记示例中学习,并以某种方式利用更多的未标记示例数据集,以便泛化到将来对新示例进行分类。

半监督 GAN,有时简称为 SGAN,是生成对抗网络架构的扩展,用于解决半监督学习问题。

这项工作的主要目标之一是提高生成对抗网络在半监督学习中的有效性(通过在额外的未标记示例上学习来提高监督任务,在本例中为分类,的性能)。

——《改进GAN训练技术》,2016。

传统 GAN 中的判别器经过训练,用于预测给定图像是真实的(来自数据集)还是伪造的(生成的),从而使其能够从无标记图像中学习特征。然后,可以通过 迁移学习 将判别器用作开发同一数据集的分类器的起点,从而使监督预测任务受益于 GAN 的无监督训练。

在半监督 GAN 中,判别器模型被更新以预测 K+1 个类别,其中 K 是预测问题中的类别数,并为新的“伪造”类别添加额外的类别标签。它涉及同时直接训练判别器模型以进行无监督 GAN 任务和监督分类任务。

我们在一个包含 N 个类别输入的 数据集上训练一个生成模型 G 和一个判别器 D。在训练时,D 被设置为预测输入属于 N+1 个类别中的哪一个,其中添加了一个额外的类别以对应于 G 的输出。

—— 《使用生成对抗网络进行半监督学习》,2016。

因此,判别器以两种模式进行训练:监督模式和无监督模式。

  • 无监督训练:在无监督模式下,判别器以与传统 GAN 相同的方式进行训练,以预测示例是真实的还是伪造的。
  • 监督训练:在监督模式下,判别器被训练来预测真实示例的类别标签。

在无监督模式下进行训练允许模型从大型未标记数据集中学习有用的特征提取能力,而在监督模式下进行训练则允许模型利用提取的特征并应用类别标签。

结果是一个分类器模型,当在很少的标记示例(例如几十、几百或一千个)上进行训练时,可以在 MNIST 等标准问题上取得最先进的结果。此外,训练过程还可以产生更高质量的生成器模型输出图像。

例如,Augustus Odena 在其 2016 年题为“《使用生成对抗网络进行半监督学习》”的论文中展示了,在 MNIST 手写数字识别任务上,当使用 25、50、100 和 1000 个标记示例进行训练时,GAN 训练的分类器能够与独立的 CNN 模型媲美甚至表现更好。

Example of the Table of Results Comparing Classification Accuracy of a CNN and SGAN on MNIST

MNIST 上 CNN 和 SGAN 分类准确度比较结果表示例。
来源:使用生成对抗网络进行半监督学习

OpenAI 的 Tim Salimans 等人在其 2016 年题为“《改进 GAN 训练技术》”的论文中,使用半监督 GAN 在包括 MNIST 在内的多个图像分类任务上取得了当时的最佳结果。

Example of the Table of Results Comparing Classification Accuracy of other GAN models to a SGAN on MNIST

MNIST 上其他 GAN 模型与 SGAN 分类准确度比较结果表示例。
来源:改进 GAN 训练技术

想从零开始开发GAN吗?

立即参加我为期7天的免费电子邮件速成课程(附示例代码)。

点击注册,同时获得该课程的免费PDF电子书版本。

如何实现半监督判别器模型

我们可以通过多种方式为半监督 GAN 实现判别器模型。

在本节中,我们将回顾三种候选方法。

传统判别器模型

考虑标准 GAN 模型的判别器模型。

它必须接受图像作为输入,并预测它是真实的还是伪造的。更具体地说,它预测输入图像为真实的概率。输出层使用 sigmoid 激活函数来预测 [0,1] 范围内的概率值,并且模型通常使用二元交叉熵损失函数进行优化。

例如,我们可以定义一个简单的判别器模型,它接受 28x28 像素的灰度图像作为输入,并预测图像为真实的概率。我们可以遵循最佳实践,使用具有 2x2 步长Leaky ReLU 激活函数 的卷积层来对图像进行下采样。

下面的 `define_discriminator()` 函数实现了这一点,并定义了我们的标准判别器模型。

运行该示例将创建判别器模型的图,清晰显示输入图像的 28x28x1 形状以及单个概率值的预测。

Plot of a Standard GAN Discriminator Model

标准 GAN 判别器模型图

具有共享权重的独立判别器模型

从标准 GAN 判别器模型开始,我们可以更新它以创建两个共享特征提取权重的模型。

具体来说,我们可以定义一个预测图像是真实还是伪造的分类器模型,以及第二个预测给定模型的类别的分类器模型。

  • 二元分类器模型。预测图像是真实的还是伪造的,输出层使用 sigmoid 激活函数,并使用 二元交叉熵损失函数 进行优化。
  • 多类别分类器模型。预测图像的类别,输出层使用 softmax 激活函数,并使用 分类交叉熵损失函数 进行优化。

这两个模型具有不同的输出层,但共享所有特征提取层。这意味着对其中一个分类器模型的更新将影响两个模型。

下面的示例首先创建具有二元输出的传统判别器模型,然后重用特征提取层并创建一个新的多类别预测模型,在本例中为 10 个类别。

运行该示例将创建并绘制两个模型。

第一个模型的图与之前相同。

Plot of an Unsupervised Binary Classification GAN Discriminator Model

无监督二元分类 GAN 判别器模型图

第二个模型的图显示了相同的预期输入形状和相同的特征提取层,并带有一个新的 10 类分类输出层。

Plot of a Supervised Multi-Class Classification GAN Discriminator Model

监督多类别分类 GAN 判别器模型图

具有共享权重的单个判别器模型

实现半监督判别器模型的另一种方法是拥有一个具有多个输出层的单一模型。

具体来说,这是一个具有一个输出层用于无监督任务,一个输出层用于监督任务的单一模型。

这类似于为监督和无监督任务拥有单独的模型,因为它们都共享相同的特征提取层,只是在这种情况下,每个输入图像总是具有两个输出预测,特别是实/伪预测和监督类别预测。

这种方法的一个问题是,当模型使用未标记和生成的图像进行更新时,没有监督类别标签。在这种情况下,这些图像的监督输出必须具有“未知”或“伪造”的输出标签。这意味着监督输出层需要一个额外的类别标签。

下面的示例实现了半监督 GAN 架构中判别器模型的多个输出的单一模型方法。

我们可以看到模型定义了两个输出层,监督任务的输出层定义为 n_classes + 1。在本例中为 11,为额外的“未知”类别标签留出了空间。

我们还可以看到模型被编译为两个损失函数,每个损失函数对应于模型的一个输出层。

运行该示例将创建并绘制单一的多输出模型。

该图清晰地显示了共享层以及独立的无监督和监督输出层。

Plot of a Semi-Supervised GAN Discriminator Model With Unsupervised and Supervised Output Layers

具有无监督和监督输出层的半监督 GAN 判别器模型图

堆叠的具有共享权重的判别器模型

最后一种方法与前两种方法非常相似,涉及创建独立的逻辑无监督和监督模型,但尝试将一个模型的输出层重用到另一个模型的输入。

该方法基于 OpenAI 的 Tim Salimans 等人在 2016 年发表的论文“《改进 GAN 训练技术》”中对半监督模型的定义。

在论文中,他们描述了一种高效的实现方式:首先创建具有 K 个输出类别和 softmax 激活函数的监督模型。然后定义无监督模型,该模型接受监督模型在 softmax 激活之前的输出,然后计算指数输出的归一化总和。

Example of the Output Function for the Unsupervised Discriminator Model in the SGAN

SGAN 中无监督判别器模型的输出函数示例。
来源:改进 GAN 训练技术

为了更清楚地说明这一点,我们可以用 NumPy 实现这个激活函数,并通过它运行一些示例激活来观察结果。

完整的示例如下所示。

请记住,无监督模型在 softmax 激活函数之前的输出将直接是节点的激活值。它们将是小的正值或负值,但未归一化,因为这将由 softmax 激活完成。

自定义激活函数将输出一个介于 0.0 和 1.0 之间的值。

对于小的或负的激活值,输出接近 0.0;对于大的或正的激活值,输出接近 1.0。当我们运行示例时,我们可以看到这一点。

这意味着模型被鼓励为真实示例输出一个强烈的类别预测,为伪造示例输出一个小的类别预测或低激活值。这是一个巧妙的技巧,允许在两个模型中重用监督模型相同的输出节点。

激活函数几乎可以直接通过 Keras 后端实现,并从 `Lambda` 层调用,例如,一个将自定义函数应用于该层输入的层。

完整的示例列在下面。首先,用 softmax 激活和分类交叉熵损失函数定义监督模型。无监督模型堆叠在监督模型输出层之上(在 softmax 激活之前),并且节点激活值通过 Lambda 层通过我们的自定义激活函数。

由于我们已经归一化了激活值,因此不需要 sigmoid 激活函数。与之前一样,无监督模型使用二元交叉熵损失进行拟合。

运行示例会创建并绘制两个模型,它们看起来与第一个示例中的两个模型非常相似。

无监督判别器模型的堆叠版本

Plot of the Stacked Version of the Unsupervised Discriminator Model of the Semi-Supervised GAN

半监督GAN的无监督判别器模型堆叠版本的图

有监督判别器模型的堆叠版本

Plot of the Stacked Version of the Supervised Discriminator Model of the Semi-Supervised GAN

半监督GAN的监督判别器模型堆叠版本的图

现在我们已经了解了如何在半监督GAN中实现判别器模型,我们可以开发一个完整的图像生成和半监督分类的示例。

如何为 MNIST 开发半监督 GAN

在本节中,我们将为MNIST手写数字数据集开发一个半监督GAN模型。

该数据集有10个数字0-9的类别,因此分类器模型将有10个输出节点。模型将拟合包含60,000个样本的训练数据集。训练数据集中只有100张图像带有标签,每个类别10张。

我们将首先定义模型。

我们将使用堆叠判别器模型,完全按照上一节的定义。

接下来,我们可以定义生成器模型。在这种情况下,生成器模型将接收潜在空间中的一个点作为输入,并使用转置卷积层输出一个28x28的灰度图像。下面的define_generator()函数实现了这一点并返回定义的生成器模型。

生成器模型将通过无监督判别器模型进行拟合。

我们将使用复合模型架构,这在Keras实现中是训练生成器模型的常用方法。具体来说,使用了权重共享,其中生成器模型的输出直接传递给无监督判别器模型,并且判别器的权重被标记为不可训练。

下面的define_gan()函数实现了这一点,它接受已定义的生成器和判别器模型作为输入,并返回用于训练生成器模型权重的复合模型。

我们可以加载训练数据集并缩放像素到[-1, 1]的范围,以匹配生成器模型的输出值。

我们还可以定义一个函数来选择训练数据集的一个子集,其中保留标签并训练判别器的有监督版本。

下面的select_supervised_samples()函数实现了这一点,并仔细确保样本的选择是随机的并且类别是平衡的。标记样本的数量被参数化并设置为100,这意味着10个类别中的每个类别将有10个随机选择的样本。

接下来,我们可以定义一个函数来检索真实训练样本的批次。

选择图像和标签的样本,并进行替换。当我们在训练模型时,可以稍后使用相同的函数从标记和未标记的数据集中检索样本。对于“未标记数据集”,我们将忽略标签。

接下来,我们可以定义用于通过生成器模型生成图像的函数。

首先,generate_latent_points()函数将创建一批随机点在潜在空间中,这些点可以用作生成图像的输入。generate_fake_samples()函数将调用此函数来生成一批图像,这些图像可以在训练期间馈送到无监督判别器模型或复合GAN模型。

接下来,我们可以定义一个函数,在需要评估模型性能时调用。

此函数将使用生成器模型的当前状态生成并绘制100张图像。这张图像图可用于主观评估生成器模型的性能。

然后,在整个训练数据集上评估有监督的判别器模型,并报告分类准确率。最后,将生成器模型和有监督的判别器模型保存到文件,以备将来使用。

下面的summarize_performance()函数实现了这一点,并且可以定期调用,例如在每个训练周期结束时。可以在运行结束时回顾结果,以选择一个分类器模型,甚至生成器模型。

接下来,我们可以定义一个函数来训练模型。定义的模型和加载的训练数据集作为参数提供,训练周期数和批次大小被参数化并具有默认值,在本例中为20个周期和100个批次大小。

选择的模型配置发现很快就过拟合了训练数据集,因此训练周期数相对较少。将周期数增加到100或更多可以生成质量高得多的图像,但分类器模型的质量会降低。平衡这两者可能是一个有趣的扩展。

首先,选择训练数据集的标记子集,并计算训练步数。

训练过程与训练标准GAN模型几乎相同,只是增加了使用标记示例更新有监督模型。

更新模型的单个周期包括首先使用标记示例更新有监督的判别器模型,然后使用未标记的真实和生成示例更新无监督的判别器模型。最后,通过复合模型更新生成器模型。

判别器模型的共享权重使用1.5批次的样本进行更新,而生成器模型的权重在每次迭代中都使用一个批次的样本进行更新。将此更改为使每个模型由相同数量更新,可能会改进模型训练过程。

最后,我们可以定义模型并调用函数来训练和保存模型。

将所有这些内容结合起来,在 MNIST 手写数字图像分类任务上训练半监督 GAN 的完整示例如下。

该示例可以在配备 CPU 或 GPU 硬件的工作站上运行,但建议使用 GPU 以获得更快的执行速度。

注意:由于算法或评估程序的随机性,或数值精度的差异,您的结果可能会有所不同。请考虑运行该示例几次并比较平均结果。

在运行开始时,将总结训练数据集的大小,以及监督子集的大小,以确认我们的配置。

每个模型的性能在每次更新结束时都会进行总结,包括监督判别器模型(c)的损失和准确率,无监督判别器在真实和生成样本上的损失(d),以及通过复合模型更新的生成器模型的损失(g)。

监督模型的损失将收缩到一个接近零的小值,准确率将达到 100%,并在整个运行过程中保持不变。如果无监督判别器和生成器保持平衡,它们的损失应该在整个运行过程中保持在适中的值。

监督分类模型在每个训练 epoch 结束时(在本例中是每 600 次训练更新后)在整个训练数据集上进行评估。此时,模型的性能得以总结,显示其能够快速获得良好的技能。

考虑到模型仅在每类 10 个标记示例上进行训练,这令人惊讶。

模型也会在每个训练 epoch 结束时保存,并且还会生成生成的图像的图。

与相对较少的训练 epoch 相比,生成的图像质量很好。

Plot of Handwritten Digits Generated by the Semi-Supervised GAN After 8400 Updates.

经过 8400 次更新后,半监督 GAN 生成的手写数字图。

如何加载和使用最终的 SGAN 分类器模型

现在我们已经训练了生成器和判别器模型,我们可以利用它们了。

在半监督 GAN 的情况下,我们对生成器模型的兴趣较小,而对监督模型的兴趣较大。

回顾特定运行的结果,我们可以选择一个已知在测试数据集上具有良好性能的特定保存模型。在本例中,是在 12 个训练 epoch 或 7,200 次更新后保存的模型,该模型在训练数据集上的分类准确率为 95.432%。

我们可以通过 Keras 的 load_model() 函数直接加载模型。

加载后,我们可以再次在整个训练数据集上评估它以确认发现,然后在其上评估保留的测试数据集。

请记住,特征提取层期望输入图像的像素值缩放到 [-1,1] 的范围,因此,在将任何图像提供给模型之前必须进行此操作。

加载已保存的半监督分类器模型并在完整的 MNIST 数据集上对其进行评估的完整示例如下。

运行该示例将加载模型并在 MNIST 数据集上对其进行评估。

注意:由于算法或评估程序的随机性,或数值精度的差异,您的结果可能会有所不同。请考虑运行该示例几次并比较平均结果。

我们可以看到,在本例中,模型在训练数据集上达到了预期的 95.432% 的性能,证实我们已加载了正确的模型。

我们还可以看到,在保留的测试数据集上的准确率同样良好,甚至略好,约为 95.920%。这表明学习到的分类器具有良好的泛化能力。

我们已经成功演示了通过 GAN 架构拟合的半监督分类器模型的训练和评估。

扩展

本节列出了一些您可能希望探索的扩展本教程的想法。

  • 独立分类器。在标记数据集上直接拟合一个独立的分类器模型,并将其性能与 SGAN 模型进行比较。
  • 标记示例数量。重复使用更多或更少的标记示例的示例,并比较模型的性能。
  • 模型调优。调整判别器和生成器模型的性能,以进一步提升监督模型的性能,使其更接近最先进的结果。

如果您探索了这些扩展中的任何一个,我很想知道。
请在下面的评论中发布您的发现。

进一步阅读

如果您想深入了解,本节提供了更多关于该主题的资源。

论文

API

文章

项目

总结

在本教程中,您将学习如何从头开始开发一个半监督生成对抗网络。

具体来说,你学到了:

  • 半监督GAN是GAN架构的扩展,用于在利用标记和未标记数据训练分类器模型。
  • 在 Keras 中实现半监督 GAN 所使用的监督和无监督判别器模型,至少有三种方法。
  • 如何从头开始在 MNIST 上训练半监督 GAN,并加载和使用训练好的分类器进行预测。

你有什么问题吗?
在下面的评论中提出你的问题,我会尽力回答。

立即开发生成对抗网络!

Generative Adversarial Networks with Python

在几分钟内开发您的GAN模型

...只需几行python代码

在我的新电子书中探索如何实现
使用 Python 构建生成对抗网络

它提供了关于以下内容的自学教程端到端项目
DCGAN条件GAN图像翻译Pix2PixCycleGAN
以及更多...

最终将GAN模型引入您的视觉项目

跳过学术理论。只看结果。

查看内容

125 条对*如何在 Keras 中从零开始实现半监督 GAN (SGAN)* 的回复

  1. Vineeta 2019 年 8 月 9 日晚上 8:54 #

    SGAN 的概念解释得非常好。您能否在上述代码中加入特征匹配?

  2. FAD 2019 年 8 月 31 日晚上 9:11 #

    感谢 Jason 的精彩解释。我运行了代码,执行时间很长。
    您认为,在这么长的执行时间内,SGAN 能否与 VAT 等其他类型的半监督学习竞争?

    • Jason Brownlee 2019 年 9 月 1 日上午 5:39 #

      干得好!

      这确实取决于问题的具体情况和模型。测试各种方法是个好主意。

  3. Hamed 2019 年 9 月 9 日下午 4:16 #

    太棒了 Jason!但是,如果我需要使用一个不在后端中的自定义函数,而不是自定义激活函数呢?例如,我生成了一个人脸图像作为输出,我想计算它的 FaceNet 嵌入,并在损失函数中惩罚生成和实际嵌入之间的差异。我认为我可以通过 tf.session.run 将其更改为数组,然后计算其 FaceNet 嵌入,但它抛出了一个错误,说我必须为占位符张量提供一个值。

    • Jason Brownlee 2019 年 9 月 10 日上午 5:36 #

      非常酷的想法!

      您可以使用自定义函数,但请记住您将使用张量而不是数组。因此,所有简单的操作都必须使用后端函数,这些函数将包装张量的 TF/Theano 函数。

      可能需要一些实验。

      • Hamed 2019 年 9 月 11 日下午 3:48 #

        你的意思是,我需要改变 Facenet 预测的所有代码吗?或者,如果我想使用 dlib 来提取每个张量的 68 个地标,我需要改变它们的代码。在他们的代码中,他们明确表示支持列表或数组作为输入。还有其他方法可以解决这个问题吗?

        • Jason Brownlee 2019 年 9 月 12 日上午 5:12 #

          不,我的意思是,如果您使用自定义函数,您将处理张量而不是数组。

  4. AB 2019 年 11 月 16 日上午 9:00 #

    您使用什么版本的 Keras 和 TensorFlow?

  5. Yingbo 2019 年 11 月 22 日上午 4:40 #

    # 更新监督判别器(c)
    [Xsup_real, ysup_real], _ = generate_real_samples([X_sup, y_sup], half_batch)
    c_loss, c_acc = c_model.train_on_batch(Xsup_real, ysup_real)

    如果我们想训练监督判别器,我们需要使用标记数据,对吗?我的意思是,我们应该使用

    [Xsup_real, ysup_real], _ = select_supervised_samples(datasets)

    而不是
    [Xsup_real, ysup_real], _ = generate_real_samples([X_sup, y_sup], half_batch)

    我很困惑,我错了还是没有?谢谢!

    • Yingbo 2019 年 11 月 22 日上午 4:46 #

      抱歉,是我的错,没关系!您的教程太棒了,真的帮了我很多!

  6. YAMIN 2019 年 12 月 9 日凌晨 1:05 #

    您能给我们提供一个关于 CatGAN 的教程吗?

  7. Double H 2019 年 12 月 19 日晚上 11:51 #

    感谢您精彩的讲解!
    我想用您的 SGAN 代码进行 2D 图像分割。
    由于我刚接触深度学习,所以想问一些问题。

    1. 将判别器当前的 1D 输出层修改为 2D 输出层是否适用于分割?

    2. 如果您知道,能否告诉我一些用于 2D 图像的 Keras 分割代码的良好示例?

    3. 如果您将代码修改为用于 2D 图像分割,您会怎么做?

    再次感谢您,祝您有美好的一天!

    • Jason Brownlee 2019 年 12 月 20 日上午 6:50 #

      抱歉,我不知道如何直接将此示例应用于分割。

  8. Zohar RImon 2019 年 12 月 26 日上午 9:26 #

    感谢您提供的精彩文章!
    我研究了您提到的两个扩展,并将 SGAN 的性能与独立分类器进行了比较。我每次都使用不同数量的标记数据来训练 SGAN 和独立分类器。
    我还为独立分类器和 SGAN 添加了数据增强(主要是因为独立分类器在没有数据增强的情况下性能非常低)。
    结果基本符合预期,但确实显示了 SGAN 在少量监督数据下的影响。
    代码和结果(生成的图)可在我的 GitHub 存储库中找到 – https://github.com/zoharri/SGAN_vs_Classifier

    • Jason Brownlee 2019 年 12 月 27 日上午 6:28 #

      谢谢。

      您做的扩展很棒!感谢分享。

  9. Kum 2020 年 1 月 15 日下午 5:13 #

    你好 Json,感谢您的解释。

    SGAN 中的判别器多分类器实例是独立训练的,并且不与生成器连接。您能否解释一下它与普通 CNN 分类器的区别?

  10. Nel 2020 年 2 月 14 日凌晨 12:07 #

    感谢这篇很棒的文章!

    我对于如何将输入馈送到具有多个输出的单个判别器模型感到困惑。假设我们有 X_train(标记)和 X_test(未标记)数据。X_test 的输出应该馈送到“d_output_layer”,X_train 的输出应该馈送到“c_output_layer”。那么,我们该如何处理?我的意思是,如何馈送数据并从所需层获取输出。

    谢谢

  11. sasi 2020 年 4 月 1 日凌晨 2:09 #

    哪种 Python、Keras 和 TensorFlow 版本适合运行上述代码?

    我为 python 3.6 遇到了以下错误

    ImportError: Could not find ‘msvcp140.dll’. TensorFlow requires that this DLL be installed in a directory that is named in your %PATH% environment variable. You may install this DLL by downloading Visual C++ 2015 Redistributable Update 3 from this URL: https://www.microsoft.com/en-us/download/details.aspx?id=53587

    • Jason Brownlee 2020 年 4 月 1 日上午 5:53 #

      Keras 2.3 和 TensorFlow 2。

      • sasi 2020 年 4 月 6 日下午 3:33 #

        感谢您抽出宝贵时间为我解答疑问。

  12. Aswin Shriram T 2020 年 4 月 3 日下午 7:48 #

    嗨,Jason,

    感谢您提供这个精彩的教程。

    我尝试将其应用于我的用例,但生成器为所有类生成相同的图像。您知道这是为什么吗?

    • Jason Brownlee 2020 年 4 月 4 日上午 6:17 #

      不客气。

      也许可以尝试更改模型的配置以适应您的数据?

  13. Aswin Shriram T 2020 年 4 月 4 日下午 3:42 #

    我确实尝试增加了每类的样本数量。还有其他可以改进这种情况的参数吗?

  14. Saumya 2020 年 4 月 25 日上午 4:18 #

    你好!感谢您的精彩解释!

    我有一个关于共享权重的堆叠判别器的问题。如果监督和无监督判别器在激活层之前共享权重。那么,为什么 c_model 有 316,938 个权重,而 d_model 有 633,876 个权重?在我看来,c_model 和 d_model 都共享公共层,并应用不同的激活函数来创建不同的模型。那么它们的权重数量应该相同吗?

    谢谢!

  15. Ambuje Gupta 2020 年 5 月 8 日下午 1:19 #

    你好,
    这是一篇很棒的文章,我能够在我的应用程序中实现它。我不太清楚应该保存哪个模型来生成图像?您最后是如何生成图像的?我应该保存哪个模型来做同样的事情?通常,我们会保存生成器模型,但我在这里有点困惑。

    谢谢 🙂

    • Jason Brownlee 2020年5月8日下午3:55 #

      谢谢!

      生成器模型用于生成图像。

      是的,我们保存生成器模型。

  16. lxxian 2020年6月9日下午10:41 #

    您好,您的文章非常好,对我有帮助!
    我有一个小问题,如果这个模型可以转换为传统的数据形式,也就是向量形式,该如何改变呢?

    谢谢!

  17. Hitesh Tekchandani 2020年6月11日上午6:04 #

    本教程中的分类器准确性是什么意思?是说分类器识别样本标签(例如0-9)的效率如何?还是分类器准确性表示分类器区分真实样本和生成(或假)样本的效率如何?此外,分类器准确性和训练准确性是一样的吗?

    • Jason Brownlee 2020年6月11日上午6:07 #

      我们在这里训练一个生成器和一个分类模型。准确性是指分类模型。

      也许重读一下“什么是半监督 GAN?”这一节。

      • Hitesh Tekchandani 2020年6月11日下午5:11 #

        如何获得生成器生成的图像的类别标签?

        • Jason Brownlee 2020年6月12日上午6:09 #

          您可以使用分类模型对生成的图像进行分类。

          • Hitesh Tekchandani 2020年6月12日下午10:07 #

            谢谢,这句话对我来说是有用的:“c_out_layer = Activation(‘softmax’)(fe)”

          • Jason Brownlee 2020年6月13日上午6:02 #

            太棒了!

  18. Asma 2020年9月4日上午9:59 #

    如何选择 latent_dim?它属于数据集吗?
    我正在尝试将此应用于另一个数据集,那么如何调整另一个数据集的 latent_dim 值?因为我使用了值=100,结果准确率很差,只有53%,是其他问题吗?

    谢谢,

    • Jason Brownlee 2020年9月4日下午1:36 #

      通常很小。模型对尺寸不太敏感,因为它们会施加自己的结构。

      • Asma 2020年9月5日凌晨3:56 #

        谢谢,我正在用我的数据尝试您的代码,我的图像尺寸是 256x256x3,而且准确率很低,所以我想知道问题是出在分类模型还是潜在维度上,因为我也显示了生成的图像,而且它们在多次迭代后质量非常差?

        • Jason Brownlee 2020年9月5日上午6:53 #

          准确性是 GAN 的一个糟糕指标,直接查看生成的图像。

          也许这里的一些建议会有帮助
          https://machinelearning.org.cn/how-to-code-generative-adversarial-network-hacks/

          • P.G 2020年9月15日上午6:37 #

            您说训练超过 100 个 epoch 会提高图像生成但不会提高分类(在第 12 个 epoch 达到最高分类器准确率的例子中),现在您又说图像质量对准确性很重要?这里的诀窍是什么????? 您能否就准确性与图像生成提供更多见解?顺便说一句,这个页面上的所有作品都太棒了,您太棒了,谢谢,继续加油!

          • Jason Brownlee 2020年9月15日上午7:41 #

            抱歉,我的评论是关于 GAN 的通用评论,例如在使用 GAN 进行图像生成时。

  19. Any 2020年9月9日上午6:05 #


    很棒的文章,我学到了很多。
    有一件事对我来说不清楚。在训练生成器时,为什么您要生成标签为 1?

    X_gan, y_gan = generate_latent_points(latent_dim, n_batch), ones((n_batch, 1))

  20. Dmitriy 2020年9月23日上午6:08 #

    嗨,Jason,
    非常感谢您提供清晰有用的文章!
    我遇到了一个奇怪的问题,当时我在 TensorFlow (2.3.0) 中改编了这个例子——模型在单个 epoch/600 步内就能训练到合理的准确率,但如果我保存然后重新加载它,准确率就会变成随机的 (~10%)。
    您知道可能是什么问题吗?


    # 训练模型
    train(g_model, d_model, c_model, gan_model, dataset, latent_dim, n_epochs=1)

    # 加载数据集
    (trainX, trainy) = dataset
    _, train_acc = c_model.evaluate(trainX, trainy, verbose=0)
    print(‘Final Accuracy: %.3f%%’ % (train_acc * 100)) # 这个值 >80%

    ### 从文件中加载 c_model 并将其应用于相同数据
    from tensorflow.keras.models import load_model
    model = load_model(‘res/c_model_0600_sm’) #.h5
    _, train_acc = model.evaluate(trainX, trainy, verbose=0)
    print(‘Final Accuracy: %.3f%%’ % (train_acc * 100)) # 这个值 <10%

    • Jason Brownlee 2020年9月23日上午6:45 #

      也许可以尝试使用独立的 Keras 库并比较您系统上的结果?

  21. Mohamed Amin 2020年9月23日下午10:08 #

    你好!

    为什么我们将 d_model.trainable 设置为 False?我以为我们在训练 d_model,并在半监督 GAN 中禁用生成器。

    提前感谢您!

  22. AShir 2020年9月24日下午12:37 #

    嗨,Jason,

    感谢这次精彩的解释!

    – 您能否介绍一下“使用 CNN 的自监督图像分类”的文章/资源?
    考虑到我们的数据集没有任何标签(注释过的标签),并且我们想对图像进行分类。您会怎么解决?

    感谢您的指导,

  23. kevin 2020年10月2日上午8:13 #

    您好,您能提供一张像这张一样的架构图吗?-> https://media.arxiv-vanity.com/render-output/3592810/semi_gans.png

    谢谢!

  24. Divine 2021年1月23日下午6:06 #

    您好,我觉得这个概念非常有趣。但是,我正在尝试运行代码以获得视觉体验,但似乎有一些错误消息。另外,我想了解一下 SGAN 是否适用于分类真实和虚假的指纹,因为我目前正在研究这个。谢谢。

    • Jason Brownlee 2021年1月24日上午5:58 #

      谢谢。

      为什么不直接使用多层感知机模型?为什么要使用 GAN?

  25. Nhung Nguyen 2021年3月1日 下午1:18 #

    嗨,Jason,

    我实现了您的代码。但我不知道为什么当我加载模型并测试模型时,准确率是 9.99%。

    • Jason Brownlee 2021年3月1日 下午1:46 #

      也许仔细检查一下您是否复制了所有代码?
      也许尝试重新拟合模型?
      也许检查库版本号?

      • Nhung nguyen 2021年3月4日下午2:32 #

        这是训练过程
        >11988, c[0.001,100], d[0.742,0.858], g[1.056]
        >11989, c[0.001,100], d[0.828,0.937], g[0.922]
        >11990, c[0.001,100], d[0.970,0.843], g[0.903]
        >11991, c[0.001,100], d[0.792,0.848], g[1.147]
        >11992, c[0.001,100], d[0.944,0.992], g[1.223]
        >11993, c[0.002,100], d[0.712,0.919], g[1.263]
        >11994, c[0.002,100], d[0.667,0.846], g[1.177]
        >11995, c[0.002,100], d[0.923,0.911], g[1.162]
        >11996, c[0.001,100], d[0.916,0.775], g[1.115]
        >11997, c[0.001,100], d[0.799,0.638], g[0.975]
        >11998, c[0.002,100], d[0.837,0.939], g[0.914]
        >11999, c[0.001,100], d[0.810,0.816], g[0.961]
        >12000, c[0.001,100], d[0.676,0.928], g[1.012]
        分类器准确率:92.422%
        >已保存:generated_plot_12000.png, g_model_12000.h5, and c_model_12000.h5

        但当我加载模型时,结果是这样的
        警告:tensorflow:加载已保存优化器状态时出错。因此,您的模型将从一个新初始化的优化器开始。
        训练准确率:9.690%
        测试准确率:9.720%

  26. Nhung nguyen 2021年3月4日下午2:29 #

    先生,您好,

    我想感谢您的回复,

    我用 MNIST 数据集运行了您的代码,它逐行完全一致,并且训练时的准确率是正确的,具有高准确率,但在我加载模型并测试之后却有所不同。我不确定原因。当我尝试使用另一个数据集时,我也遇到了类似的错误。

    • Jason Brownlee 2021年3月5日上午5:30 #

      这太奇怪了!我以前从未见过这种情况。

      我想知道模型是否没有正确保存/加载,例如,也许在保存前后检查权重值。

      • Nhung nguyen 2021年3月9日下午3:50 #

        您好,非常感谢您的评论。
        我检查了 TensorFlow 和 Keras,现在结果没问题了。
        再次非常感谢。

  27. Guilherme Andrade 2021年3月25日晚上11:42 #

    您好,当您设置 d_model.trainable = False 时,这是否不会阻止判别器学习?也许 train_on_batch 会覆盖这一点,但我正在使用 tf.gradienttape 尝试,并且除非我在训练过程中将 d_model.trainable 设置为 True,否则我无法更改模型的权重。

    • Jason Brownlee 2021年3月26日上午6:25 #

      不行。

      它只在复合模型中生效。您可以在 API 文档中了解有关层冻结的更多信息。

      您仍然可以手动更改权重——“trainable”标志被 Keras API 在调用 fit() 等内部尊重。

  28. Guilherme Andrade 2021年3月26日凌晨1:46 #

    你好,
    通过将 d_model.trainable 设置为 False,这是否阻止了判别器和分类器学习?至少我尝试过使用 GradientTape。也许 train_on_batch 会覆盖它?是我理解错了什么吗?

  29. Wakil Khan 2021年3月28日凌晨12:04 #

    如何加载自定义数据集?假设我以前在我的机器上有一个图像数据集。

  30. Wakil Khan 2021年3月31日下午7:51 #

    我没明白潜空间到底是什么?

  31. kevin 2021年4月11日上午11:13 #

    您好,当我使用自定义激活时,我的判别器模型的损失没有改善。

    还有,为什么我们需要在训练集而不是验证集上评估数据?

  32. kevin 2021年4月11日上午11:27 #

    您好,我有一些问题。

    1. 为什么我们需要在训练集上评估而不是验证集?
    2. 我不知道为什么,但当我使用 custom_activation 时,监督判别器的损失没有改善。

    谢谢

  33. kevin 2021年4月25日下午1:37 #

    感谢您的回复,您是最棒的。

    我还有另一个问题。您为什么要使用 Adam 并设置 lr = 0.0002 和 beta = 0.5?这会影响 GAN 和结果的质量吗?

    • Jason Brownlee 2021年4月26日上午5:34 #

      不客气。

      是的,我相信这种 Adam 配置通常是推荐的,并且对于 GAN 模型普遍有效。

  34. kevin 2021年4月29日下午1:54 #

    当我使用自定义激活时,我的判别器对真实样本的损失总是太高(约 0.75~),准确率仅为 0.02。而我的判别器对假样本的损失较低(约 0.3~),准确率接近 0.98。这正常吗?

    • Jason Brownlee 2021年4月30日上午6:01 #

      也许可以与您特定数据集上的其他模型、其他配置进行比较。

  35. Juan Monte 2021年5月13日下午7:48 #

    再次感谢,Jason!

    一如既往,非常感谢您在这里分享这些信息。

    我有一个疑问:是否可以将预训练模型(例如 VGG16)包含进生成器?这样可以帮助生成器利用预训练模型提取的特征来生成更好的图像。

    感谢您的任何回复。

    • Jason Brownlee 2021年5月14日上午6:25 #

      不客气。

      您也许可以使用预训练模型,但我预计性能会更差。您可以尝试一下。

  36. Nate Yeli 2021年6月24日上午11:24 #

    嗨,Brownlee先生,

    非常感谢您提供这个资源。我想知道如何使用自定义数据集(96×96 图像,13 个类别)来运行这个示例。

    我该如何更新 n_nodes = 128 * 7 * 7,以及 gen = Reshape((7, 7, 128))(gen),和 gen = Conv2DTranspose(128, (4,4), strides=(2,2), padding=’same’)(gen)?

    以及我该如何更改 in_shape=(28,28,1)?

    我查阅了 https://machinelearning.org.cn/how-to-load-and-manipulate-images-for-deep-learning-in-python-with-pil-pillow/,但仍然有些卡住。关于这个示例,任何帮助都会很感激。

    • Jason Brownlee 2021年6月25日上午6:09 #

      抱歉,我不确定在这种情况下该如何提供指导。您可能需要进行一些尝试和错误来调整模型以适应您的数据集。

  37. Melvin 2021年6月25日下午2:02 #

    自定义激活函数是否可以用于二分类问题,还是只能用于多类问题?

  38. Kevin 2021年8月2日凌晨2:42 #

    当我绘制 discriminator.summary() 时,它显示我有不可训练的参数,这正常吗?

    • Jason Brownlee 2021年8月2日上午4:54 #

      这看起来确实很奇怪。

      • Kevin 2021年8月2日下午1:37 #

        您可以做一个测试吗?我在这里测试过,当调用 build_gan() 时就会发生这种情况。也许当 train_on_batch 时,模型即使 trainable=False 也会训练,但当我们使用 fit() 时就不行了。

  39. Kumar 2021年11月25日上午7:56 #

    你好,

    我正在将 GAN 用于全监督问题。但是,我的指标非常差。您能检查一下我下面的训练循环——特别是 criterion 部分吗?

    谢谢

    • Adrian Tam
      Adrian Tam 2021年11月25日 下午2:39 #

      在我看来似乎还可以。GAN可能需要很长时间才能训练好。

  40. Lovely 2021年12月20日 下午10:42 #

    很棒的文章。

    • James Carmichael 2021年12月24日 上午6:07 #

      感谢您的反馈和赞美!

      此致,

  41. Noman 2021年12月23日 上午8:06 #

    先生您好,我有一个非常基础的问题。我们如何计算神经网络的训练时间?

  42. Noman 2021年12月25日 上午6:43 #

    我正在使用 TensorFlow 和 Google Colab。

  43. Noman 2021年12月28日 上午12:16 #

    我的神经网络每个 epoch 需要 4 秒,每一步需要 2 毫秒。我的目标是计算神经网络的总训练时间。所以,我对此感到困惑。

  44. ZMB 2021年12月30日 下午10:11 #

    您好,您能分享一下此代码的 PyTorch 实现吗?

  45. Juan 2022年5月8日 上午3:41 #

    你好,

    在具有共享权重的独立判别器模型中,您是如何训练它的?

    我的意思是,如果分类器模型没有额外的标签来预测类别是否为“未知”,您只能在有标签的图像上训练该模型,对吗?

    那么您是先将所有图像(有标签和无标签的真实图像,以及生成的假图像)输入到第一个判别器,然后将少量有标签的图像输入到第二个判别器吗?

    谢谢。

  46. Mahdi Mohammadi 2022年6月17日 下午8:30 #


    在您写“bat_per_epo = int(dataset[0].shape[0] / n_batch)”的 train 函数的行中,
    难道不应该是
    bat_per_epo = int(X_sup.shape[0] / n_batch)
    ?
    因为 dataset[0] 包含整个 mnist x_train,而我们实际上只使用了其中的 100 张图片。

    • James Carmichael 2022年6月18日 上午10:46 #

      感谢您的反馈 Mahdi!

  47. Mohammed 2022年9月6日 下午10:53 #

    我们可以将代码用于另一个数据集(二分类)吗?如果可以,如何做?

  48. Arega 2022年11月9日 上午7:51 #

    我可以在多标签分类特征提取中使用 GAN 作为图像到 CSV 文件的判别器吗?

  49. Ashay 2023年3月26日 上午4:11 #

    您好,当您说您训练了 20 个 epoch,n_batch=100 时

    如果我想将其与基于 CNN 的分类器进行比较,我应该也运行它 20 个 epoch,还是其他数量?

    • James Carmichael 2023年3月26日 上午10:30 #

      您好 Ashay…您的理解是正确的!建议您按您所说的那样做。

      • Ashay 2023年3月26日 上午11:43 #

        谢谢,那么如果我的理解是正确的,即使步数约为 1000 步(batch_size=256),GAN 也只训练 20 个 epoch 吗?

        在这里,其他未标记的样本被用于训练整个数据集,然后生成模型可以学习更好的特征,然后分类器判别器模型训练 20 个 epoch,但当正常判别器更新权重时,它也会获得权重更新。

        • James Carmichael 2023年3月27日 上午10:42 #

          您好 Ashay…您说得对!请与我们分享您的 GAN 模型表现如何!

          • Ashay 2023年3月27日 下午8:28 #

            谢谢,我注意到对于有 10 个类别的医学图像数据,当每类至少有 300 个以上样本时,基线 CNN 的表现接近 ssgan。直到 100 个样本时,仍然有超过 5% 的准确率差异。

  50. ramsey morton 2023年3月26日 上午4:27 #

    就像 MNIST 一样,分类器模型是 7200c.h5,这是否意味着它比训练了 20 个 epoch 的传统 ML 分类器训练得更多?

  51. ramsey morton 2023年3月26日 上午4:28 #

    只需相应地更改损失函数、n_classes 和样本即可,例如这里的 100 可能意味着每个 50。

留下回复

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。