PyTorch 中的训练和验证数据

训练数据是机器学习算法用来学习的数据集,也称为训练集。验证数据是机器学习算法用来测试其准确性的数据集之一。验证算法的性能是指将其预测输出与验证数据中的已知真实值进行比较。

训练数据通常很大且复杂,而验证数据通常较小。训练样本越多,模型性能越好。例如,在垃圾邮件检测任务中,如果训练集中有 10 封垃圾邮件和 10 封非垃圾邮件,那么机器学习模型很难检测新邮件中的垃圾邮件,因为它没有关于垃圾邮件看起来是怎样的足够信息。但是,如果我们有 1000 万封垃圾邮件和 1000 万封非垃圾邮件,那么我们的模型就更容易检测新垃圾邮件,因为它看到了很多它看起来的例子。

在本教程中,您将学习 PyTorch 中的训练和验证数据。我们还将展示训练和验证数据对机器学习模型的重要性,特别是对神经网络的关注。具体来说,您将学习:

  • PyTorch 中训练和验证数据的概念。
  • 如何在 PyTorch 中将数据分割成训练集和验证集。
  • 如何使用 PyTorch 中的内置函数构建简单的线性回归模型。
  • 如何使用各种学习率来训练模型以获得所需的准确度。
  • 如何调整超参数以获得数据的最佳模型。

通过我的《用PyTorch进行深度学习》一书来启动你的项目。它提供了包含可用代码自学教程


让我们开始吧。

使用 PyTorch 中的优化器。
图片由 Markus Krisetya 拍摄。保留部分权利。

概述

本教程分为三个部分;它们是

  • 构建训练集和验证集的数据类
  • 构建和训练模型
  • 可视化结果

构建训练集和验证集的数据类

首先,让我们加载本教程所需的一些库。

我们将从构建自定义数据集类开始,以生成足够的合成数据。这将允许我们将数据分割成训练集和验证集。此外,我们还将添加一些步骤将异常值包含在数据中。

对于训练集,我们将 `train` 参数默认设置为 `True`。如果设置为 `False`,它将生成验证数据。我们已将训练集和验证集创建为单独的对象。

现在,让我们可视化我们的数据。您将在 x=-2 和 x=0 处看到异常值。

训练和验证数据集

生成上述图表的完整代码如下。

构建和训练模型

PyTorch 的 `nn` 包为我们提供了许多有用的函数。我们将从 `nn` 包中导入线性回归模型和损失函数。此外,我们还将从 `torch.utils.data` 包中导入 `DataLoader`。

我们将创建一个包含各种学习率的列表,以一次性训练多个模型。这是深度学习从业者中的一种常见做法,他们会调整不同的超参数来获得最佳模型。我们将把训练和验证损失都存储在张量中,并创建一个名为 `Models` 的空列表来存储我们的模型。稍后,我们将绘制图表来评估我们的模型。

为了训练模型,我们将使用各种学习率和随机梯度下降(SGD)优化器。训练和验证数据的结果将与模型一起保存在列表中。我们将所有模型训练 20 个 epoch。

上面的代码分别收集训练和验证的损失。这有助于我们了解我们的训练效果如何,例如,是否过拟合。如果我们发现验证集的损失与训练集的损失差异很大,那么它就过拟合了。在这种情况下,我们训练好的模型无法泛化到它未见过的数据,即验证集。

想开始使用PyTorch进行深度学习吗?

立即参加我的免费电子邮件速成课程(附示例代码)。

点击注册,同时获得该课程的免费PDF电子书版本。

可视化结果

在上面,我们使用了相同的模型(线性回归)并以固定的 epoch 数进行训练。唯一的变化是学习率。然后我们可以比较哪个学习率在最快收敛方面给我们带来了最好的模型。

让我们可视化每个学习率的训练和验证数据的损失图。通过查看图表,您可以观察到在学习率为 0.001 时损失最小,这意味着我们的模型在该学习率下对于此数据收敛得更快。

损失与学习率

我们还绘制每个模型在验证数据上的预测。一个完全收敛的模型应该能完美地拟合数据,而一个远未收敛的模型会产生偏离数据的预测。

我们看到预测可视化如下:

如您所见,绿线更接近验证数据点。它是具有最佳学习率(0.001)的线。

以下是从创建数据到可视化训练和验证损失的完整代码。

总结

在本教程中,您学习了 PyTorch 中训练和验证数据的概念。具体来说,您学习了:

  • PyTorch 中训练和验证数据的概念。
  • 如何在 PyTorch 中将数据分割成训练集和验证集。
  • 如何使用 PyTorch 中的内置函数构建简单的线性回归模型。
  • 如何使用各种学习率来训练模型以获得所需的准确度。
  • 如何调整超参数以获得数据的最佳模型。

开始使用PyTorch进行深度学习!

Deep Learning with PyTorch

学习如何构建深度学习模型

...使用新发布的PyTorch 2.0库

在我的新电子书中探索如何实现
使用 PyTorch进行深度学习

它提供了包含数百个可用代码自学教程,让你从新手变成专家。它将使你掌握:
张量操作训练评估超参数优化等等...

通过动手练习开启你的深度学习之旅


查看内容

对 *PyTorch 中的训练和验证数据* 的一则回复

  1. T.Hofkamp 2023年5月8日下午7:01 #

    关于训练和验证的精彩文章。但我缺失的信息如下:
    我有一些带有参数的对象。我想将参数输入 AI(PyTorch),AI 会对对象做出决策。每个对象都有一个 id。我不想将 objectid 输入 AI,因为它可能会影响结果。AI 的输出是决策的数组/流,但没有 objectid。
    有没有办法将 AI 的结果与相应的 objectid 一起获取?
    (如果 AI 在管道中,具有多个写入器和结果读取器,那将非常方便)

留下回复

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。