在 PyTorch 中使用交叉熵损失训练逻辑回归

在我们的 PyTorch 系列上一节中,我们展示了当使用均方误差 (MSE) 损失时,初始化不当的权重如何影响分类模型的准确性。我们注意到模型在训练过程中没有收敛,并且其准确性也显著降低。

接下来,您将看到如果随机初始化权重并使用交叉熵作为模型训练的损失函数会发生什么。这种损失函数更适合逻辑回归和其他分类问题。因此,交叉熵损失被用于当今大多数分类问题。

在本教程中,您将使用交叉熵损失训练一个逻辑回归模型,并对测试数据进行预测。具体来说,您将学习

  • 如何在 PyTorch 中使用交叉熵损失训练逻辑回归模型。
  • 交叉熵损失如何影响模型准确性。

通过我的《用PyTorch进行深度学习》一书来启动你的项目。它提供了包含可用代码自学教程


让我们开始吧。

在 PyTorch 中使用交叉熵损失训练逻辑回归。
图片来源:Y K。保留部分权利。

概述

本教程分为三个部分;它们是

  • 准备数据和构建模型
  • 使用交叉熵进行模型训练
  • 使用测试数据进行验证

准备数据和模型

就像之前的教程一样,您将构建一个类来获取数据集以执行实验。该数据集将分为训练样本和测试样本。测试样本是用于衡量训练模型性能的未见数据。

首先,我们创建一个 Dataset

然后,实例化数据集对象。

接下来,您将为我们的逻辑回归模型构建一个自定义模块。它将基于 PyTorch 的 nn.Module 的属性和方法。这个包允许我们为我们的深度学习模型构建复杂的自定义模块,并使整个过程变得容易得多。

该模块只包含一个线性层,如下所示:

让我们创建模型对象。

该模型应该具有随机权重。您可以通过打印其状态来检查这一点

您可能会看到

想开始使用PyTorch进行深度学习吗?

立即参加我的免费电子邮件速成课程(附示例代码)。

点击注册,同时获得该课程的免费PDF电子书版本。

使用交叉熵进行模型训练

回想一下,在之前的教程中,当您将这些参数值与 MSE 损失一起使用时,该模型没有收敛。让我们看看使用交叉熵损失时会发生什么。

由于您正在执行具有一个输出的逻辑回归,因此这是一个具有两个类别的分类问题。换句话说,这是一个二元分类问题,因此我们使用二元交叉熵。您按如下方式设置优化器和损失函数。

接下来,我们准备一个 DataLoader 并训练模型 50 个 epoch。

训练期间的输出将如下所示:

如您所见,损失在训练过程中减少并收敛到最小值。我们还将绘制训练图。

您将看到以下内容:

使用测试数据进行验证

上图显示模型在训练数据上表现良好。最后,让我们检查模型在未见数据上的表现。

结果如下:

当模型使用 MSE 损失进行训练时,表现不佳。之前它的准确率约为 57%。但在这里,我们得到了完美的预测。部分原因在于模型简单,是一个单变量逻辑函数。部分原因在于我们正确设置了训练。因此,如我们的实验所示,交叉熵损失显著提高了模型的准确性,优于 MSE 损失。

把所有东西放在一起,下面是完整的代码。

总结

在本教程中,您学习了交叉熵损失如何影响分类模型的性能。具体来说,您学习了

  • 如何在 PyTorch 中使用交叉熵损失训练逻辑回归模型。
  • 交叉熵损失如何影响模型准确性。

开始使用PyTorch进行深度学习!

Deep Learning with PyTorch

学习如何构建深度学习模型

...使用新发布的PyTorch 2.0库

在我的新电子书中探索如何实现
使用 PyTorch进行深度学习

它提供了包含数百个可用代码自学教程,让你从新手变成专家。它将使你掌握:
张量操作训练评估超参数优化等等...

通过动手练习开启你的深度学习之旅


查看内容

对《在 PyTorch 中使用交叉熵损失训练逻辑回归》的回复

  1. Michael 2023 年 3 月 28 日上午 7:31 #

    嗨,M

    代码不错,但我认为您需要对数据进行训练/测试拆分,以证明模型确实泛化良好。

发表评论

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。