深度学习模型是对数据的数学抽象,其中包含大量参数。训练这些参数可能需要数小时、数天甚至数周,但之后,您可以将结果应用于新数据。这在机器学习中称为推理。了解如何将训练好的模型保存在磁盘上,并在以后加载它以用于推理非常重要。在本文中,您将了解如何将 PyTorch 模型保存到文件,并将它们重新加载以进行预测。阅读本章后,您将了解:
- PyTorch 模型中的状态和参数是什么
- 如何保存模型状态
- 如何加载模型状态
通过我的《用PyTorch进行深度学习》一书来启动你的项目。它提供了包含可用代码的自学教程。
让我们开始吧。

保存和加载您的 PyTorch 模型
照片由 Joseph Chan 提供。保留部分权利。
概述
这篇文章分为三个部分;它们是
- 构建示例模型
- PyTorch 模型内部有什么
- 访问模型的
state_dict
构建示例模型
让我们从一个非常简单的 PyTorch 模型开始。这是一个基于鸢尾花数据集的模型。您将使用 scikit-learn 加载数据集(其目标是整数标签 0、1 和 2),并为此多类分类问题训练一个神经网络。在此模型中,您将 log softmax 作为输出激活,因此可以将其与负对数似然损失函数结合使用。这相当于没有输出激活并结合交叉熵损失函数。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import torch import torch.nn as nn import torch.optim as optim from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split # 将数据加载到 NumPy 数组中 data = load_iris() X, y = data["data"], data["target"] # 将 NumPy 数组转换为 PyTorch 张量 X = torch.tensor(X, dtype=torch.float32) y = torch.tensor(y, dtype=torch.long) # 划分 X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, shuffle=True) # PyTorch 模型 class Multiclass(nn.Module): def __init__(self): super().__init__() self.hidden = nn.Linear(4, 8) self.act = nn.ReLU() self.output = nn.Linear(8, 3) self.logsoftmax = nn.LogSoftmax(dim=1) def forward(self, x): x = self.act(self.hidden(x)) x = self.logsoftmax(self.output(x)) return x model = Multiclass() # 损失指标和优化器 loss_fn = nn.NLLLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) # 准备模型和训练参数 n_epochs = 100 batch_size = 5 batch_start = torch.arange(0, len(X), batch_size) # 训练循环 for epoch in range(n_epochs): for start in batch_start: # 获取一个批次 X_batch = X_train[start:start+batch_size] y_batch = y_train[start:start+batch_size] # 前向传播 y_pred = model(X_batch) loss = loss_fn(y_pred, y_batch) # 反向传播 optimizer.zero_grad() loss.backward() # 更新权重 optimizer.step() |
对于这样简单的模型和小型数据集,训练应该不会花费很长时间。之后,我们可以通过在测试集上进行评估来确认此模型是否有效。
1 2 3 4 |
... y_pred = model(X_test) acc = (torch.argmax(y_pred, 1) == y_test).float().mean() print("Accuracy: %.2f" % acc) |
例如,它会打印:
1 |
Accuracy: 0.96 |
想开始使用PyTorch进行深度学习吗?
立即参加我的免费电子邮件速成课程(附示例代码)。
点击注册,同时获得该课程的免费PDF电子书版本。
PyTorch 模型内部有什么
PyTorch 模型是 Python 中的一个对象。它包含一些深度学习构建块,例如各种类型的层和激活函数。它还知道如何连接它们,以便能够从您的输入张量生成输出。模型的算法在创建时是固定的,但它具有可训练的参数,这些参数需要在训练循环中进行修改,以便模型能够更准确。
您已经了解了在为训练循环设置优化器时如何获取模型参数,具体是:
1 |
optimizer = optim.Adam(model.parameters(), lr=0.001) |
model.parameters()
函数会提供一个生成器,该生成器依次引用每个层中以 PyTorch 张量形式存在的可训练参数。因此,您可以复制它们或覆盖它们,例如:
1 2 3 4 5 6 7 8 9 10 |
# 创建一个新模型 newmodel = Multiclass() # 要求 PyTorch 在更新时忽略 autograd 并覆盖参数 with torch.no_grad(): for newtensor, oldtensor in zip(newmodel.parameters(), model.parameters()): newtensor.copy_(oldtensor) # 使用复制的张量进行新模型测试 y_pred = newmodel(X_test) acc = (torch.argmax(y_pred, 1) == y_test).float().mean() print("Accuracy: %.2f" % acc) |
结果应该与之前完全相同,因为您通过复制参数基本上使这两个模型相同。
然而,情况并非总是如此。有些模型具有非可训练参数。一个例子是批量归一化层,它在许多卷积神经网络中很常见。它的作用是对其前一层产生的张量进行归一化,并将归一化后的张量传递给下一层。它有两个参数:均值和标准差,它们在训练循环中从您的输入数据中学习,但不能被优化器训练。因此,这些不属于 model.parameters()
,但同样重要。
访问模型的 state_dict
要访问模型的所有参数,无论是否可训练,都可以从 state_dict()
函数中获取。以上模型,您可以得到:
1 2 3 |
import pprint pp = pprint.PrettyPrinter(indent=4) pp.pprint(model.state_dict()) |
上述模型产生以下输出:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
OrderedDict([ ( 'hidden.weight', tensor([[ 0.1480, 0.0336, 0.3425, 0.2832], [ 0.5265, 0.8587, -0.7023, -1.1149], [ 0.1620, 0.8440, -0.6189, -0.6513], [-0.1559, 0.0393, -0.4701, 0.0825], [ 0.6364, -0.6622, 1.1150, 0.9162], [ 0.2081, -0.0958, -0.2601, -0.3148], [-0.0804, 0.1027, 0.7363, 0.6068], [-0.4101, -0.3774, -0.1852, 0.1524]])), ( 'hidden.bias', tensor([ 0.2057, 0.7998, -0.0578, 0.1041, -0.3903, -0.4521, -0.5307, -0.1532])), ( 'output.weight', tensor([[-0.0954, 0.8683, 1.0667, 0.2382, -0.4245, -0.0409, -0.2587, -0.0745], [-0.0829, 0.8642, -1.6892, -0.0188, 0.0420, -0.1020, 0.0344, -0.1210], [-0.0176, -1.2809, -0.3040, 0.1985, 0.2423, 0.3333, 0.4523, -0.1928]])), ('output.bias', tensor([ 0.0998, 0.6360, -0.2990]))]) |
它被称为 state_dict
,因为模型的所有状态变量都在这里。它是一个 Python 内置 collections
模块的 OrderedDict
对象。PyTorch 模型的所有组件都有一个名称,其中的参数也是如此。OrderedDict
对象允许您通过匹配名称将权重正确地映射回参数。
这就是您应该如何保存和加载模型:将模型状态提取到 OrderedDict
中,进行序列化并将其保存到磁盘。对于推理,您首先创建一个模型(无需训练),然后加载状态。在 Python 中,序列化的原生格式是 pickle。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import pickle # 保存模型 with open("iris-model.pickle", "wb") as fp: pickle.dump(model.state_dict(), fp) # 创建新模型并加载状态 newmodel = Multiclass() with open("iris-model.pickle", "rb") as fp: newmodel.load_state_dict(pickle.load(fp)) # 使用复制的张量进行新模型测试 y_pred = newmodel(X_test) acc = (torch.argmax(y_pred, 1) == y_test).float().mean() print("Accuracy: %.2f" % acc) |
您知道它有效,因为您未训练过的模型产生了与您训练过的模型相同的结果。
确实,推荐的方式是使用 PyTorch API 来保存和加载状态,而不是手动使用 pickle。
1 2 3 4 5 6 7 8 9 10 11 |
# 保存模型 torch.save(model.state_dict(), "iris-model.pth") # 创建新模型并加载状态 newmodel = Multiclass() newmodel.load_state_dict(torch.load("iris-model.pth")) # 使用复制的张量进行新模型测试 y_pred = newmodel(X_test) acc = (torch.argmax(y_pred, 1) == y_test).float().mean() print("Accuracy: %.2f" % acc) |
*.pth
文件实际上是 PyTorch 创建的一些 pickle 文件的 zip 文件。它之所以被推荐,是因为 PyTorch 可以在其中存储额外的信息。请注意,您只存储了状态,而没有存储模型。您仍然需要使用 Python 代码创建模型,并将状态加载到其中。如果您希望也存储模型,您可以传递整个模型而不是状态。
1 2 3 4 5 6 7 8 9 10 |
# 保存模型 torch.save(model, "iris-model-full.pth") # 加载模型 newmodel = torch.load("iris-model-full.pth") # 使用复制的张量进行新模型测试 y_pred = newmodel(X_test) acc = (torch.argmax(y_pred, 1) == y_test).float().mean() print("Accuracy: %.2f" % acc) |
但请记住,由于 Python 语言的性质,这样做并不能让您摆脱保留模型代码的麻烦。上面的 newmodel
对象是您之前定义的 Multiclass
类的实例。当您从磁盘加载模型时,Python 需要详细了解此类的定义。如果您只运行包含 torch.load()
行的脚本,您将看到以下错误消息:
1 2 3 4 5 6 7 8 9 |
Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/.../torch/serialization.py", line 789, in load return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args) File "/.../torch/serialization.py", line 1131, in _load result = unpickler.load() File "/.../torch/serialization.py", line 1124, in find_class return super().find_class(mod_name, name) AttributeError: Can't get attribute 'Multiclass' on <module '__main__' (built-in)> |
这就是为什么建议仅保存 state dict 而不是整个模型。
总而言之,以下是演示如何创建模型、训练模型并将其保存到磁盘的完整代码。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import torch import torch.nn as nn import torch.optim as optim from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split # 将数据加载到 NumPy 数组中 data = load_iris() X, y = data["data"], data["target"] # 将 NumPy 数组转换为 PyTorch 张量 X = torch.tensor(X, dtype=torch.float32) y = torch.tensor(y, dtype=torch.long) # 划分 X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, shuffle=True) # PyTorch 模型 class Multiclass(nn.Module): def __init__(self): super().__init__() self.hidden = nn.Linear(4, 8) self.act = nn.ReLU() self.output = nn.Linear(8, 3) self.logsoftmax = nn.LogSoftmax(dim=1) def forward(self, x): x = self.act(self.hidden(x)) x = self.logsoftmax(self.output(x)) return x model = Multiclass() # 损失指标和优化器 loss_fn = nn.NLLLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) # 准备模型和训练参数 n_epochs = 100 batch_size = 5 batch_start = torch.arange(0, len(X), batch_size) # 训练循环 for epoch in range(n_epochs): for start in batch_start: # 获取一个批次 X_batch = X_train[start:start+batch_size] y_batch = y_train[start:start+batch_size] # 前向传播 y_pred = model(X_batch) loss = loss_fn(y_pred, y_batch) # 反向传播 optimizer.zero_grad() loss.backward() # 更新权重 optimizer.step() # 保存模型 torch.save(model.state_dict(), "iris-model.pth") |
以下是如何从磁盘加载模型并运行它以进行推理。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import torch import torch.nn as nn import torch.optim as optim from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split # 将数据加载到 NumPy 数组中 data = load_iris() X, y = data["data"], data["target"] # 将 NumPy 数组转换为 PyTorch 张量 X = torch.tensor(X, dtype=torch.float32) y = torch.tensor(y, dtype=torch.long) # PyTorch 模型 class Multiclass(nn.Module): def __init__(self): super().__init__() self.hidden = nn.Linear(4, 8) self.act = nn.ReLU() self.output = nn.Linear(8, 3) self.logsoftmax = nn.LogSoftmax(dim=1) def forward(self, x): x = self.act(self.hidden(x)) x = self.logsoftmax(self.output(x)) return x # 创建新模型并加载状态 model = Multiclass() with open("iris-model.pickle", "rb") as fp: model.load_state_dict(pickle.load(fp)) # 运行模型进行推理 y_pred = model(X_test) acc = (torch.argmax(y_pred, 1) == y_test).float().mean() print("Accuracy: %.2f" % acc) |
进一步阅读
如果您想深入了解,本节提供了更多关于该主题的资源。
总结
在本文中,您学习了如何将训练好的 PyTorch 模型的副本保存在磁盘上以及如何重用它。特别是,您学到了:
- PyTorch 模型中的参数和状态是什么
- 如何将模型的所有必要状态保存到磁盘
- 如何从保存的状态重建一个可工作的模型
谢谢,这真的很有用,只是一个小评论。在最后的示例代码块中,加载操作显示为使用 picklle 而不是 torch.load。Torch load 会比之前显示的保存代码更一致。
Javier,不客气!感谢您的反馈和建议!
您好,我认为在最后的代码窗口,31-32行
应该是:“with open(“iris-model.pth”, “rb”) as fp
model.load_state_dict(torch.load(fp))”
Poult,感谢您的反馈和建议!