一个大型深度学习模型可能需要很长时间才能训练。如果训练过程在中间被中断,你将会丢失大量工作。但有时,你确实想在训练过程中间中断它,因为你知道再继续下去也不会得到更好的模型。在这篇文章中,你将发现如何在 PyTorch 中控制训练循环,以便你可以恢复中断的过程,或者提前停止训练循环。
完成这篇文章后,您将了解:
- 在训练期间对神经网络模型进行检查点的重要性
- 如何在训练期间对模型进行检查点并在之后恢复它
- 如何通过检查点提前终止训练循环
通过我的《用PyTorch进行深度学习》一书来启动你的项目。它提供了包含可用代码的自学教程。
让我们开始吧。

使用检查点和提前停止管理 PyTorch 训练过程
照片来源: Arron Choi。部分权利保留。
概述
本章分为两部分;它们是
- 神经网络模型的检查点
- 带有提前停止的检查点
神经网络模型的检查点
许多系统都有状态。如果你能保存系统的所有状态并在之后恢复它,你就可以随时回溯系统行为的特定时间点。如果你用过 Microsoft Word,并保存了文档的多个版本,因为你不确定是否要撤销编辑,这和这里的想法是一样的。
对于长时间运行的进程也是如此。应用程序检查点是一种容错技术。在这种方法中,会在系统发生故障时捕获系统的状态快照。如果出现问题,你可以从快照中恢复。检查点可以直接使用,也可以作为新运行的起点,从中断的地方继续。训练深度学习模型时,检查点捕获模型权重。这些权重可用于进行预测,或作为持续训练的基础。
PyTorch 不提供任何检查点功能,但它有检索和恢复模型权重的函数。因此,你可以使用它们来实现检查点逻辑。让我们创建一个检查点和恢复函数,它们只是简单地保存模型的权重并将它们加载回来。
1 2 3 4 5 6 7 |
import torch def checkpoint(model, filename): torch.save(model.state_dict(), filename) def resume(model, filename): model.load_state_dict(torch.load(filename)) |
下面是你通常训练 PyTorch 模型的方式。使用的数据集是从 OpenML 平台获取的。它是一个二元分类数据集。此示例中使用了 PyTorch DataLoader 来使训练循环更简洁。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import TensorDataset, DataLoader, random_split, default_collate from sklearn.datasets import fetch_openml from sklearn.preprocessing import LabelEncoder data = fetch_openml("electricity", version=1, parser="auto") # 对目标进行标签编码,转换为浮点张量 X = data['data'].astype('float').values y = data['target'] 编码器 = LabelEncoder() 编码器。fit(y) y = encoder.transform(y) X = torch.tensor(X, dtype=torch.float32) y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1) # 用于模型评估的训练-测试拆分 trainset, testset = random_split(TensorDataset(X, y), [0.7, 0.3]) # 定义模型 model = nn.Sequential( nn.Linear(8, 12), nn.ReLU(), nn.Linear(12, 12), nn.ReLU(), nn.Linear(12, 1), nn.Sigmoid(), ) # 训练模型 n_epochs = 100 loader = DataLoader(trainset, shuffle=True, batch_size=32) X_test, y_test = default_collate(testset) loss_fn = nn.BCELoss() optimizer = optim.SGD(model.parameters(), lr=0.1) for epoch in range(n_epochs): model.train() for X_batch, y_batch in loader: y_pred = model(X_batch) loss = loss_fn(y_pred, y_batch) optimizer.zero_grad() loss.backward() optimizer.step() model.eval() y_pred = model(X_test) acc = (y_pred.round() == y_test).float().mean() print(f"Epoch {epoch} 结束: accuracy = {float(acc)*100:.2f}%") |
如果你想在训练循环中添加检查点,你可以在外层 for 循环结束时添加,也就是进行模型与测试集验证的地方。假设如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
... for epoch in range(n_epochs): model.train() for X_batch, y_batch in loader: y_pred = model(X_batch) loss = loss_fn(y_pred, y_batch) optimizer.zero_grad() loss.backward() optimizer.step() model.eval() y_pred = model(X_test) acc = (y_pred.round() == y_test).float().mean() print(f"Epoch {epoch} 结束: accuracy = {float(acc)*100:.2f}%") checkpoint(model, f"epoch-{epoch}.pth") |
你会在你的工作目录中看到许多文件被创建。这段代码将从第 7 个 epoch 开始检查点模型,例如,保存到文件 epoch-7.pth
。这些文件中的每一个都是一个包含模型权重 pickle 数据的 ZIP 文件。没有什么能阻止你在内部 for 循环中进行检查点,但由于它产生的开销,不建议过于频繁地进行检查点。
作为一种容错技术,通过在训练循环之前添加几行代码,你可以从特定的 epoch 恢复训练。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
start_epoch = 0 if start_epoch > 0: resume_epoch = start_epoch - 1 resume(model, f"epoch-{resume_epoch}.pth") for epoch in range(start_epoch, n_epochs): model.train() for X_batch, y_batch in loader: y_pred = model(X_batch) loss = loss_fn(y_pred, y_batch) optimizer.zero_grad() loss.backward() optimizer.step() model.eval() y_pred = model(X_test) acc = (y_pred.round() == y_test).float().mean() print(f"Epoch {epoch} 结束: accuracy = {float(acc)*100:.2f}%") checkpoint(model, f"epoch-{epoch}.pth") |
也就是说,如果训练循环在中途被中断,那么最后一个检查点来自 epoch 7,设置 start_epoch = 8
即可。
请注意,如果你这样做,random_split()
函数会生成训练集和测试集,但由于随机性,可能会得到不同的分割。如果这对你来说是一个顾虑,你应该有一个一致的方法来创建数据集(例如,保存分割后的数据以便重复使用)。
有时,模型之外还有其他状态,你也可能想对它们进行检查点。一个特别的例子是优化器,在 Adam 等情况下,它有动态调整的动量。如果你重新启动了训练循环,你可能也想恢复优化器的动量。这并不难。关键是让你的 checkpoint()
函数更复杂,例如:
1 2 3 4 |
torch.save({ 'optimizer': optimizer.state_dict(), 'model': model.state_dict(), }, filename) |
并相应地更改你的 resume()
函数:
1 2 3 |
checkpoint = torch.load(filename) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) |
这之所以可行,是因为在 PyTorch 中,torch.save()
和 torch.load()
函数由 pickle
提供支持,因此你可以将其与 list
或 dict
容器一起使用。
将所有内容整合在一起,完整的代码如下。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import TensorDataset, DataLoader, random_split, default_collate from sklearn.datasets import fetch_openml from sklearn.preprocessing import LabelEncoder data = fetch_openml("electricity", version=1, parser="auto") # 对目标进行标签编码,转换为浮点张量 X = data['data'].astype('float').values y = data['target'] 编码器 = LabelEncoder() 编码器。fit(y) y = encoder.transform(y) X = torch.tensor(X, dtype=torch.float32) y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1) # 用于模型评估的训练-测试拆分 trainset, testset = random_split(TensorDataset(X, y), [0.7, 0.3]) def checkpoint(model, filename): torch.save(model.state_dict(), filename) def resume(model, filename): model.load_state_dict(torch.load(filename)) # 定义模型 model = nn.Sequential( nn.Linear(8, 12), nn.ReLU(), nn.Linear(12, 12), nn.ReLU(), nn.Linear(12, 1), nn.Sigmoid(), ) # 训练模型 n_epochs = 100 start_epoch = 0 loader = DataLoader(trainset, shuffle=True, batch_size=32) X_test, y_test = default_collate(testset) loss_fn = nn.BCELoss() optimizer = optim.SGD(model.parameters(), lr=0.1) if start_epoch > 0: resume_epoch = start_epoch - 1 resume(model, f"epoch-{resume_epoch}.pth") for epoch in range(start_epoch, n_epochs): model.train() for X_batch, y_batch in loader: y_pred = model(X_batch) loss = loss_fn(y_pred, y_batch) optimizer.zero_grad() loss.backward() optimizer.step() model.eval() y_pred = model(X_test) acc = (y_pred.round() == y_test).float().mean() print(f"Epoch {epoch} 结束: accuracy = {float(acc)*100:.2f}%") checkpoint(model, f"epoch-{epoch}.pth") |
想开始使用PyTorch进行深度学习吗?
立即参加我的免费电子邮件速成课程(附示例代码)。
点击注册,同时获得该课程的免费PDF电子书版本。
带有提前停止的检查点
检查点不仅是为了容错。你也可以用它来保存你的最佳模型。如何定义什么是最佳模型是主观的,但考虑测试集的得分是一种明智的方法。假设你只想保留找到的最佳模型,可以如下修改训练循环:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
... best_accuracy = -1 for epoch in range(start_epoch, n_epochs): model.train() for X_batch, y_batch in loader: y_pred = model(X_batch) loss = loss_fn(y_pred, y_batch) optimizer.zero_grad() loss.backward() optimizer.step() model.eval() y_pred = model(X_test) acc = (y_pred.round() == y_test).float().mean() acc = float(acc) * 100 print(f"Epoch {epoch} 结束: accuracy = {acc:.2f}%") if acc > best_accuracy: best_accuracy = acc checkpoint(model, "best_model.pth") resume(model, "best_model.pth") |
变量 best_accuracy
用于跟踪迄今为止获得的最大准确率(acc
),其百分比范围为 0 到 100。每当观察到更高的准确率时,模型就会被检查点保存到文件 best_model.pth
。在整个训练循环之后,通过之前创建的 resume()
函数来恢复最佳模型。之后,您就可以对模型在未见过的数据上进行预测。请注意,如果您使用不同的指标进行检查点,例如交叉熵损失,那么更好的模型应该具有较低的交叉熵。因此,您应该跟踪获得的最低交叉熵。
你也可以在每个 epoch 无条件地检查点模型,并与最佳模型检查点一起进行,因为你可以自由创建多个检查点文件。由于上面的代码是找出最佳模型并复制它,所以你通常会看到对训练循环的进一步优化,即如果模型改进的希望渺茫,就提前停止。这就是可以节省训练时间的提前停止技术。
上面的代码在每个 epoch 结束时用测试集验证模型,并将找到的最佳模型保存在检查点文件中。提前停止的最简单策略是设置一个 $k$ epoch 的阈值。如果你在最近的 $k$ 个 epoch 中没有看到模型有任何改进,你就会中断训练循环。这可以通过以下方式实现:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
early_stop_thresh = 5 best_accuracy = -1 best_epoch = -1 for epoch in range(start_epoch, n_epochs): model.train() for X_batch, y_batch in loader: y_pred = model(X_batch) loss = loss_fn(y_pred, y_batch) optimizer.zero_grad() loss.backward() optimizer.step() model.eval() y_pred = model(X_test) acc = (y_pred.round() == y_test).float().mean() acc = float(acc) * 100 print(f"Epoch {epoch} 结束: accuracy = {acc:.2f}%") if acc > best_accuracy: best_accuracy = acc best_epoch = epoch checkpoint(model, "best_model.pth") elif epoch - best_epoch > early_stop_thresh: print("提前在 epoch %d 停止训练" % epoch) break # 终止训练循环 resume(model, "best_model.pth") |
上面的 early_stop_thresh
阈值设置为 5。有一个变量 best_epoch
记录了最佳模型的 epoch。如果模型长时间没有改进,外层 for 循环将被终止。
这个设计缓解了 n_epochs
这个设计参数。你现在可以将 n_epochs
设置为训练模型的 **最大** epoch 数,也就是一个比实际需要更大的数字,并确保你的训练循环通常会更早停止。这也是一种避免过拟合的策略:如果模型在测试集上的表现确实随着训练的进行而变差,那么这个提前停止的逻辑将中断训练并恢复最佳检查点。
将所有内容整合在一起,以下是带有提前停止的检查点的完整代码。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import TensorDataset, DataLoader, random_split, default_collate from sklearn.datasets import fetch_openml from sklearn.preprocessing import LabelEncoder data = fetch_openml("electricity", version=1, parser="auto") # 对目标进行标签编码,转换为浮点张量 X = data['data'].astype('float').values y = data['target'] 编码器 = LabelEncoder() 编码器。fit(y) y = encoder.transform(y) X = torch.tensor(X, dtype=torch.float32) y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1) # 用于模型评估的训练-测试拆分 trainset, testset = random_split(TensorDataset(X, y), [0.7, 0.3]) def checkpoint(model, filename): torch.save(model.state_dict(), filename) def resume(model, filename): model.load_state_dict(torch.load(filename)) # 定义模型 model = nn.Sequential( nn.Linear(8, 12), nn.ReLU(), nn.Linear(12, 12), nn.ReLU(), nn.Linear(12, 1), nn.Sigmoid(), ) # 训练模型 n_epochs = 10000 # 比我们需要的还多 loader = DataLoader(trainset, shuffle=True, batch_size=32) X_test, y_test = default_collate(testset) loss_fn = nn.BCELoss() optimizer = optim.SGD(model.parameters(), lr=0.1) early_stop_thresh = 5 best_accuracy = -1 best_epoch = -1 for epoch in range(n_epochs): model.train() for X_batch, y_batch in loader: y_pred = model(X_batch) loss = loss_fn(y_pred, y_batch) optimizer.zero_grad() loss.backward() optimizer.step() model.eval() y_pred = model(X_test) acc = (y_pred.round() == y_test).float().mean() acc = float(acc) * 100 print(f"Epoch {epoch} 结束: accuracy = {acc:.2f}%") if acc > best_accuracy: best_accuracy = acc best_epoch = epoch checkpoint(model, "best_model.pth") elif epoch - best_epoch > early_stop_thresh: print("提前在 epoch %d 停止训练" % epoch) break # 终止训练循环 resume(model, "best_model.pth") |
你可能会看到上面的代码产生以下输出:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
Epoch 0 结束: accuracy = 61.84% Epoch 1 结束: accuracy = 55.90% Epoch 2 结束: accuracy = 63.95% Epoch 3 结束: accuracy = 66.87% Epoch 4 结束: accuracy = 64.77% Epoch 5 结束: accuracy = 60.03% Epoch 6 结束: accuracy = 67.16% Epoch 7 结束: accuracy = 66.01% Epoch 8 结束: accuracy = 62.88% Epoch 9 结束: accuracy = 64.28% Epoch 10 结束: accuracy = 68.63% Epoch 11 结束: accuracy = 70.56% Epoch 12 结束: accuracy = 64.62% Epoch 13 结束: accuracy = 65.63% Epoch 14 结束: accuracy = 66.81% Epoch 15 结束: accuracy = 65.11% Epoch 16 结束: accuracy = 55.81% Epoch 17 结束: accuracy = 54.59% 提前在 epoch 17 停止训练 |
它在 epoch 17 结束时停止,最佳模型来自 epoch 11。由于算法的随机性,你可能会看到结果略有不同。但可以肯定的是,即使上面将最大 epoch 数设置为 10000,训练循环也确实提前停止了。
当然,你可以设计一个更复杂的提前停止策略,例如,先运行至少 $N$ 个 epoch,然后允许在 $k$ 个 epoch 后提前停止。你有充分的自由来调整上面的代码,以获得最适合你需求的训练循环。
总结
在本章中,你了解了为长训练过程进行深度学习模型检查点的重要性。你学习了:
- 什么是检查点以及它为什么有用
- 如何对模型进行检查点以及如何恢复检查点
- 使用检查点的不同策略
- 如何使用检查点实现提前停止
谢谢。你知道有任何文章使用检查点来节省内存并允许额外的梯度累积步骤吗?
你好 AM……以下讨论可能对你有帮助
https://ai.stackexchange.com/questions/31675/what-is-better-to-use-early-stopping-model-checkpoint-or-both
感谢您的教程。
我很好奇
torch.utils.checkpoint
在这方面起到了什么作用?一些反馈:写作
start_epoch = 0
if start_epoch > 0
有点令人困惑,因为乍一看 if 条件永远不会满足。我猜意图是用户应该更改初始值 0 为他们想要开始的期望值,但这个意图并不十分明显。我建议将其设为命令行参数,或者至少添加一个注释:# 更改此值以获得所需的起始 epoch。
好的,不用了,经过一些搜索,我发现
torch.utils.checkpoint
有完全不同的目的,那就是通过丢弃一些值并在需要时重新计算它们来节省 GPU 内存。如果你问我,这是一个不幸的名字选择。