为 PyTorch 模型创建训练循环

PyTorch 为深度学习模型提供了许多构建块,但训练循环不属于其中。这种灵活性允许你在训练期间做任何想做的事情,但一些基本结构在大多数用例中是通用的。

在这篇文章中,你将学习如何创建一个训练循环,为你的模型训练提供基本信息,并可以选择显示任何信息。完成本文后,你将知道

  • 训练循环的基本组成部分
  • 如何使用 tqdm 显示训练进度

通过我的《用PyTorch进行深度学习》一书来启动你的项目。它提供了包含可用代码自学教程


让我们开始吧。

为 PyTorch 模型创建训练循环
图片来源: pat pat。保留部分权利。

概述

这篇文章分为三个部分;它们是

  • 深度学习模型训练的要素
  • 训练期间收集统计数据
  • 使用 tqdm 报告训练进度

深度学习模型训练的要素

与所有机器学习模型一样,模型设计指定了操作输入并产生输出的算法。但在模型中,你需要微调参数才能实现这一点。这些模型参数也称为权重、偏差、核或其他名称,具体取决于特定的模型和层。训练是向模型输入样本数据,以便优化器可以微调这些参数。

当你训练模型时,通常从数据集开始。每个数据集都包含大量数据样本。当你获得数据集时,建议将其分成两部分:训练集和测试集。训练集进一步分成批次,用于训练循环以驱动梯度下降算法。然而,测试集用作基准,以衡量模型的好坏。通常,你不使用训练集作为度量标准,而是使用梯度下降算法未见过的测试集,这样你就可以判断模型是否很好地拟合了未见过的数据。

过拟合是指模型对训练集拟合得太好(即精度非常高),但在测试集上表现明显更差。欠拟合是指模型甚至无法很好地拟合训练集。当然,你希望一个好的模型既不过拟合也不欠拟合。

神经网络模型的训练以 epoch 进行。通常,一个 epoch 意味着你遍历整个训练集一次,尽管你一次只输入一个批次。在每个 epoch 结束时,通常还会执行一些内务管理任务,例如使用测试集对部分训练的模型进行基准测试、保存模型检查点、决定是否提前停止训练以及收集训练统计数据等。

在每个 epoch 中,你以批次形式将数据样本输入模型,并运行梯度下降算法。这是训练循环中的一步,因为你运行模型一次前向传播(即提供输入并捕获输出),以及一次反向传播(从输出评估损失指标并推导出每个参数的梯度,一直追溯到输入层)。反向传播使用自动微分计算梯度。然后,此梯度由梯度下降算法用于调整模型参数。一个 epoch 中有多个步骤。

重用上一教程中的示例,你可以下载数据集并按如下方式将数据集分成两部分

这个数据集很小——只有 768 个样本。这里,它将前 700 个作为训练集,其余的作为测试集。

这不是本文的重点,但你可以重用之前文章中的模型、损失函数和优化器

有了数据和模型,这就是最小的训练循环,每个步骤都有前向和反向传播

在内部 for 循环中,你获取数据集中的每个批次并评估损失。损失是一个 PyTorch 张量,它会记住它是如何得出其值的。然后你将优化器管理的所有梯度归零,并调用 loss.backward() 来运行反向传播算法。结果设置了张量 loss 直接和间接依赖的所有张量的梯度。之后,在调用 step() 时,优化器将检查它管理的每个参数并更新它们。

一切完成后,你可以使用测试集运行模型以评估其性能。评估可以基于与损失函数不同的函数。例如,此分类问题使用准确率

将所有内容整合在一起,这就是完整的代码

训练期间收集统计数据

上面的训练循环对于可以在几秒钟内完成训练的小模型来说应该很好用。但是对于更大的模型或更大的数据集,你会发现训练时间明显更长。在等待训练完成时,你可能想看看进展如何,因为如果出现任何错误,你可能想中断训练。

通常,在训练过程中,你希望看到以下内容

  • 在每个步骤中,你希望知道损失指标,并且你期望损失下降
  • 在每个步骤中,你希望知道其他指标,例如训练集上的准确性,这些指标很重要但未参与梯度下降
  • 在每个 epoch 结束时,你希望使用测试集评估部分训练的模型并报告评估指标
  • 在训练结束时,你希望能够可视化上述指标

这些都是可能的,但你需要向训练循环添加更多代码,如下所示

当你将损失和准确性收集到列表中时,你可以使用 matplotlib 绘制它们。但请注意,你是在每个步骤中收集训练集统计数据,而测试集准确性仅在 epoch 结束时收集。因此,你希望显示每个 epoch 中训练循环的平均准确性,以便它们可以相互比较。

将所有内容整合在一起,下面是完整的代码

故事并未到此结束。实际上,你可以向训练循环添加更多代码,尤其是在处理更复杂的模型时。一个例子是检查点。你可能希望保存模型(例如,使用 pickle),这样,如果出于任何原因程序停止,你可以从中间重新开始训练循环。另一个例子是早期停止,它允许你监控每个 epoch 结束时使用测试集获得的准确性,如果模型在一段时间内没有改进,则中断训练。这是因为在给定模型设计的情况下,你可能无法进一步改进,并且你不希望过拟合。

想开始使用PyTorch进行深度学习吗?

立即参加我的免费电子邮件速成课程(附示例代码)。

点击注册,同时获得该课程的免费PDF电子书版本。

使用 tqdm 报告训练进度

如果你运行上面的代码,你会发现训练循环运行时屏幕上会打印很多行。你的屏幕可能会变得混乱。你可能还希望看到一个动画进度条,以便更好地告诉你训练进度。库 tqdm 是创建进度条的流行工具。将上面的代码转换为使用 tqdm 再简单不过了

tqdm 的用法是使用 trange() 创建一个迭代器,就像 Python 的 range() 函数一样,你可以在循环中读取数字。你可以通过更新其描述或“后缀”数据来访问进度条,但你必须在其内容耗尽之前完成此操作。set_postfix() 函数功能强大,因为它可以显示任何内容。

事实上,除了 trange() 之外还有一个 tqdm() 函数,它可以迭代现有列表。你可能会觉得它更容易使用,你可以将上面的循环重写如下

以下是完整的代码(不包括 matplotlib 绘图)

总结

在这篇文章中,你详细了解了如何正确设置 PyTorch 模型的训练循环。具体来说,你看到了

  • 训练循环中需要实现的元素
  • 训练循环如何将训练数据连接到梯度下降优化器
  • 如何收集训练循环中的信息并显示它们

开始使用PyTorch进行深度学习!

Deep Learning with PyTorch

学习如何构建深度学习模型

...使用新发布的PyTorch 2.0库

在我的新电子书中探索如何实现
使用 PyTorch进行深度学习

它提供了包含数百个可用代码自学教程,让你从新手变成专家。它将使你掌握:
张量操作训练评估超参数优化等等...

通过动手练习开启你的深度学习之旅


查看内容

暂无评论。

发表评论

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。