使用检查点和提前停止管理 PyTorch 训练过程

一个大型深度学习模型可能需要很长时间才能训练。如果训练过程在中间被中断,你将会丢失大量工作。但有时,你确实想在训练过程中间中断它,因为你知道再继续下去也不会得到更好的模型。在这篇文章中,你将发现如何在 PyTorch 中控制训练循环,以便你可以恢复中断的过程,或者提前停止训练循环。

完成这篇文章后,您将了解:

  • 在训练期间对神经网络模型进行检查点的重要性
  • 如何在训练期间对模型进行检查点并在之后恢复它
  • 如何通过检查点提前终止训练循环

通过我的《用PyTorch进行深度学习》一书来启动你的项目。它提供了包含可用代码自学教程


让我们开始吧。

使用检查点和提前停止管理 PyTorch 训练过程
照片来源: Arron Choi。部分权利保留。

概述

本章分为两部分;它们是

  • 神经网络模型的检查点
  • 带有提前停止的检查点

神经网络模型的检查点

许多系统都有状态。如果你能保存系统的所有状态并在之后恢复它,你就可以随时回溯系统行为的特定时间点。如果你用过 Microsoft Word,并保存了文档的多个版本,因为你不确定是否要撤销编辑,这和这里的想法是一样的。

对于长时间运行的进程也是如此。应用程序检查点是一种容错技术。在这种方法中,会在系统发生故障时捕获系统的状态快照。如果出现问题,你可以从快照中恢复。检查点可以直接使用,也可以作为新运行的起点,从中断的地方继续。训练深度学习模型时,检查点捕获模型权重。这些权重可用于进行预测,或作为持续训练的基础。

PyTorch 不提供任何检查点功能,但它有检索和恢复模型权重的函数。因此,你可以使用它们来实现检查点逻辑。让我们创建一个检查点和恢复函数,它们只是简单地保存模型的权重并将它们加载回来。

下面是你通常训练 PyTorch 模型的方式。使用的数据集是从 OpenML 平台获取的。它是一个二元分类数据集。此示例中使用了 PyTorch DataLoader 来使训练循环更简洁。

如果你想在训练循环中添加检查点,你可以在外层 for 循环结束时添加,也就是进行模型与测试集验证的地方。假设如下:

你会在你的工作目录中看到许多文件被创建。这段代码将从第 7 个 epoch 开始检查点模型,例如,保存到文件 epoch-7.pth。这些文件中的每一个都是一个包含模型权重 pickle 数据的 ZIP 文件。没有什么能阻止你在内部 for 循环中进行检查点,但由于它产生的开销,不建议过于频繁地进行检查点。

作为一种容错技术,通过在训练循环之前添加几行代码,你可以从特定的 epoch 恢复训练。

也就是说,如果训练循环在中途被中断,那么最后一个检查点来自 epoch 7,设置 start_epoch = 8 即可。

请注意,如果你这样做,random_split() 函数会生成训练集和测试集,但由于随机性,可能会得到不同的分割。如果这对你来说是一个顾虑,你应该有一个一致的方法来创建数据集(例如,保存分割后的数据以便重复使用)。

有时,模型之外还有其他状态,你也可能想对它们进行检查点。一个特别的例子是优化器,在 Adam 等情况下,它有动态调整的动量。如果你重新启动了训练循环,你可能也想恢复优化器的动量。这并不难。关键是让你的 checkpoint() 函数更复杂,例如:

并相应地更改你的 resume() 函数:

这之所以可行,是因为在 PyTorch 中,torch.save()torch.load() 函数由 pickle 提供支持,因此你可以将其与 listdict 容器一起使用。

将所有内容整合在一起,完整的代码如下。

想开始使用PyTorch进行深度学习吗?

立即参加我的免费电子邮件速成课程(附示例代码)。

点击注册,同时获得该课程的免费PDF电子书版本。

带有提前停止的检查点

检查点不仅是为了容错。你也可以用它来保存你的最佳模型。如何定义什么是最佳模型是主观的,但考虑测试集的得分是一种明智的方法。假设你只想保留找到的最佳模型,可以如下修改训练循环:

变量 best_accuracy 用于跟踪迄今为止获得的最大准确率(acc),其百分比范围为 0 到 100。每当观察到更高的准确率时,模型就会被检查点保存到文件 best_model.pth。在整个训练循环之后,通过之前创建的 resume() 函数来恢复最佳模型。之后,您就可以对模型在未见过的数据上进行预测。请注意,如果您使用不同的指标进行检查点,例如交叉熵损失,那么更好的模型应该具有较低的交叉熵。因此,您应该跟踪获得的最低交叉熵。

你也可以在每个 epoch 无条件地检查点模型,并与最佳模型检查点一起进行,因为你可以自由创建多个检查点文件。由于上面的代码是找出最佳模型并复制它,所以你通常会看到对训练循环的进一步优化,即如果模型改进的希望渺茫,就提前停止。这就是可以节省训练时间的提前停止技术。

上面的代码在每个 epoch 结束时用测试集验证模型,并将找到的最佳模型保存在检查点文件中。提前停止的最简单策略是设置一个 $k$ epoch 的阈值。如果你在最近的 $k$ 个 epoch 中没有看到模型有任何改进,你就会中断训练循环。这可以通过以下方式实现:

上面的 early_stop_thresh 阈值设置为 5。有一个变量 best_epoch 记录了最佳模型的 epoch。如果模型长时间没有改进,外层 for 循环将被终止。

这个设计缓解了 n_epochs 这个设计参数。你现在可以将 n_epochs 设置为训练模型的 **最大** epoch 数,也就是一个比实际需要更大的数字,并确保你的训练循环通常会更早停止。这也是一种避免过拟合的策略:如果模型在测试集上的表现确实随着训练的进行而变差,那么这个提前停止的逻辑将中断训练并恢复最佳检查点。

将所有内容整合在一起,以下是带有提前停止的检查点的完整代码。

你可能会看到上面的代码产生以下输出:

它在 epoch 17 结束时停止,最佳模型来自 epoch 11。由于算法的随机性,你可能会看到结果略有不同。但可以肯定的是,即使上面将最大 epoch 数设置为 10000,训练循环也确实提前停止了。

当然,你可以设计一个更复杂的提前停止策略,例如,先运行至少 $N$ 个 epoch,然后允许在 $k$ 个 epoch 后提前停止。你有充分的自由来调整上面的代码,以获得最适合你需求的训练循环。

总结

在本章中,你了解了为长训练过程进行深度学习模型检查点的重要性。你学习了:

  • 什么是检查点以及它为什么有用
  • 如何对模型进行检查点以及如何恢复检查点
  • 使用检查点的不同策略
  • 如何使用检查点实现提前停止

开始使用PyTorch进行深度学习!

Deep Learning with PyTorch

学习如何构建深度学习模型

...使用新发布的PyTorch 2.0库

在我的新电子书中探索如何实现
使用 PyTorch进行深度学习

它提供了包含数百个可用代码自学教程,让你从新手变成专家。它将使你掌握:
张量操作训练评估超参数优化等等...

通过动手练习开启你的深度学习之旅


查看内容

4 条对 使用检查点和提前停止管理 PyTorch 训练过程 的回复

  1. AM 2023年9月4日 上午8:12 #

    谢谢。你知道有任何文章使用检查点来节省内存并允许额外的梯度累积步骤吗?

  2. Rodrigo 2024年10月3日 上午9:11 #

    感谢您的教程。

    我很好奇 torch.utils.checkpoint 在这方面起到了什么作用?

    一些反馈:写作

    start_epoch = 0
    if start_epoch > 0

    有点令人困惑,因为乍一看 if 条件永远不会满足。我猜意图是用户应该更改初始值 0 为他们想要开始的期望值,但这个意图并不十分明显。我建议将其设为命令行参数,或者至少添加一个注释:# 更改此值以获得所需的起始 epoch。

  3. Rodrigo 2024年10月3日 上午9:17 #

    好的,不用了,经过一些搜索,我发现 torch.utils.checkpoint 有完全不同的目的,那就是通过丢弃一些值并在需要时重新计算它们来节省 GPU 内存。如果你问我,这是一个不幸的名字选择。

留下回复

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。