PyTorch 中 Softmax 分类器入门

虽然逻辑回归分类器用于二元分类,但 Softmax 分类器是一种有监督学习算法,主要用于涉及多个类别的情况。

Softmax 分类器通过为每个类别分配概率分布来工作。具有最高概率的类别的概率分布被归一化为 1,所有其他概率相应缩放。

同样,Softmax 函数将神经元的输出转换为类别的概率分布。它具有以下特性:

  1. 它与逻辑 Sigmoid 相关,逻辑 Sigmoid 用于概率建模并具有相似的特性。
  2. 它接收 0 到 1 之间的值,其中 0 对应于不可能发生的事件,1 对应于肯定会发生的事件。
  3. Softmax 关于输入 x 的导数可以解释为预测给定输入 x 时选择特定类别的可能性。

在本教程中,我们将构建一个一维 Softmax 分类器并探索其功能。特别是,我们将学习:

  • 如何使用 Softmax 分类器进行多类别分类。
  • 如何在 PyTorch 中构建和训练 Softmax 分类器。
  • 如何分析模型在测试数据上的结果。

通过我的《用PyTorch进行深度学习》一书来启动你的项目。它提供了包含可用代码自学教程


让我们开始吧。

PyTorch 中的 Softmax 分类器入门。
图片来自 Julia Caesar。部分权利保留。

概述

本教程分为四个部分;它们是:

  • 准备数据集
  • 将数据集加载到 DataLoader 中
  • 使用 nn.Module 构建模型
  • 训练分类器

准备数据集

首先,让我们构建我们的数据集类来生成一些数据样本。与之前的实验不同,您将生成多类数据。然后,您将在这些数据样本上训练 Softmax 分类器,之后使用它来对测试数据进行预测。

下面,我们根据单个输入变量为四个类别生成数据。

让我们创建数据对象,并检查前十个数据样本及其标签。

输出如下:

使用 nn.Module 构建 Softmax 模型

我们将使用 PyTorch 的 nn.Module 来构建一个自定义的 Softmax 模块。它与您在之前的逻辑回归教程中构建的自定义模块类似。那么,这里的区别是什么呢?之前您在 n_ouputs 的位置使用了 1 进行二元分类,而在这里我们将定义四个类别用于多类别分类。其次,在 forward() 函数中,模型不使用逻辑函数进行预测。

现在,让我们创建模型对象。它接收一个一维向量作为输入,并为四个不同的类别进行预测。我们还可以看看参数是如何初始化的。

输出如下:

想开始使用PyTorch进行深度学习吗?

立即参加我的免费电子邮件速成课程(附示例代码)。

点击注册,同时获得该课程的免费PDF电子书版本。

训练模型

结合随机梯度下降,我们将使用交叉熵损失进行模型训练,并将学习率设置为 0.01。我们将数据加载到数据加载器中,并将批量大小设置为 2。

一切准备就绪后,让我们将模型训练 100 个 epoch。

训练循环完成后,您调用模型的 max() 方法进行预测。参数 1 返回轴一的最大值,即返回每列中的最大值索引。

从上面,您应该会看到

这些是模型在测试数据上的预测。

让我们也检查一下模型的准确率。

在这种情况下,您可能会看到

在这个简单的模型中,如果您训练更长时间,您可以看到准确率接近 1。

把所有东西放在一起,下面是完整的代码。

总结

在本教程中,您学习了如何构建一个简单的一维 Softmax 分类器。特别是,您学习了:

  • 如何使用 Softmax 分类器进行多类别分类。
  • 如何在 PyTorch 中构建和训练 Softmax 分类器。
  • 如何分析模型在测试数据上的结果。

开始使用PyTorch进行深度学习!

Deep Learning with PyTorch

学习如何构建深度学习模型

...使用新发布的PyTorch 2.0库

在我的新电子书中探索如何实现
使用 PyTorch进行深度学习

它提供了包含数百个可用代码自学教程,让你从新手变成专家。它将使你掌握:
张量操作训练评估超参数优化等等...

通过动手练习开启你的深度学习之旅


查看内容

“PyTorch 中的 Softmax 分类器入门” 的 4 条回复

  1. John William O'Meara 2023年7月3日 上午3:39 #

    如果我说错了,请纠正我,但这似乎根本没有实现 Softmax 分类器。
    定义的名为“Softmax()”的自定义类只有一个线性层,没有 Softmax 激活。
    这是疏忽吗?

    • James Carmichael 2023年7月3日 上午7:59 #

      你好 John… 本教程介绍了概念。如果您使用它,请告诉我们,并在您自己的项目中分享您的发现。

      • JC 2024年11月24日 上午10:02 #

        您的输出维度是 4,但您的预测只需要 1 个数字,我看不到您如何从 4 得到 1。

        • James Carmichael 2024年11月25日 上午7:25 #

          你好 JC… 你能运行代码来确认你的输出吗?请及时告知你的进展。

留下回复

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。