保存和加载您的 PyTorch 模型

深度学习模型是对数据的数学抽象,其中包含大量参数。训练这些参数可能需要数小时、数天甚至数周,但之后,您可以将结果应用于新数据。这在机器学习中称为推理。了解如何将训练好的模型保存在磁盘上,并在以后加载它以用于推理非常重要。在本文中,您将了解如何将 PyTorch 模型保存到文件,并将它们重新加载以进行预测。阅读本章后,您将了解:

  • PyTorch 模型中的状态和参数是什么
  • 如何保存模型状态
  • 如何加载模型状态

通过我的《用PyTorch进行深度学习》一书来启动你的项目。它提供了包含可用代码自学教程


让我们开始吧。

保存和加载您的 PyTorch 模型
照片由 Joseph Chan 提供。保留部分权利。

概述

这篇文章分为三个部分;它们是

  • 构建示例模型
  • PyTorch 模型内部有什么
  • 访问模型的 state_dict

构建示例模型

让我们从一个非常简单的 PyTorch 模型开始。这是一个基于鸢尾花数据集的模型。您将使用 scikit-learn 加载数据集(其目标是整数标签 0、1 和 2),并为此多类分类问题训练一个神经网络。在此模型中,您将 log softmax 作为输出激活,因此可以将其与负对数似然损失函数结合使用。这相当于没有输出激活并结合交叉熵损失函数。

对于这样简单的模型和小型数据集,训练应该不会花费很长时间。之后,我们可以通过在测试集上进行评估来确认此模型是否有效。

例如,它会打印:

想开始使用PyTorch进行深度学习吗?

立即参加我的免费电子邮件速成课程(附示例代码)。

点击注册,同时获得该课程的免费PDF电子书版本。

PyTorch 模型内部有什么

PyTorch 模型是 Python 中的一个对象。它包含一些深度学习构建块,例如各种类型的层和激活函数。它还知道如何连接它们,以便能够从您的输入张量生成输出。模型的算法在创建时是固定的,但它具有可训练的参数,这些参数需要在训练循环中进行修改,以便模型能够更准确。

您已经了解了在为训练循环设置优化器时如何获取模型参数,具体是:

model.parameters() 函数会提供一个生成器,该生成器依次引用每个层中以 PyTorch 张量形式存在的可训练参数。因此,您可以复制它们或覆盖它们,例如:

结果应该与之前完全相同,因为您通过复制参数基本上使这两个模型相同。

然而,情况并非总是如此。有些模型具有非可训练参数。一个例子是批量归一化层,它在许多卷积神经网络中很常见。它的作用是对其前一层产生的张量进行归一化,并将归一化后的张量传递给下一层。它有两个参数:均值和标准差,它们在训练循环中从您的输入数据中学习,但不能被优化器训练。因此,这些不属于 model.parameters(),但同样重要。

访问模型的 state_dict

要访问模型的所有参数,无论是否可训练,都可以从 state_dict() 函数中获取。以上模型,您可以得到:

上述模型产生以下输出:

它被称为 state_dict,因为模型的所有状态变量都在这里。它是一个 Python 内置 collections 模块的 OrderedDict 对象。PyTorch 模型的所有组件都有一个名称,其中的参数也是如此。OrderedDict 对象允许您通过匹配名称将权重正确地映射回参数。

这就是您应该如何保存和加载模型:将模型状态提取到 OrderedDict 中,进行序列化并将其保存到磁盘。对于推理,您首先创建一个模型(无需训练),然后加载状态。在 Python 中,序列化的原生格式是 pickle。

您知道它有效,因为您未训练过的模型产生了与您训练过的模型相同的结果。

确实,推荐的方式是使用 PyTorch API 来保存和加载状态,而不是手动使用 pickle。

*.pth 文件实际上是 PyTorch 创建的一些 pickle 文件的 zip 文件。它之所以被推荐,是因为 PyTorch 可以在其中存储额外的信息。请注意,您只存储了状态,而没有存储模型。您仍然需要使用 Python 代码创建模型,并将状态加载到其中。如果您希望也存储模型,您可以传递整个模型而不是状态。

但请记住,由于 Python 语言的性质,这样做并不能让您摆脱保留模型代码的麻烦。上面的 newmodel 对象是您之前定义的 Multiclass 类的实例。当您从磁盘加载模型时,Python 需要详细了解此类的定义。如果您只运行包含 torch.load() 行的脚本,您将看到以下错误消息:

这就是为什么建议仅保存 state dict 而不是整个模型。

总而言之,以下是演示如何创建模型、训练模型并将其保存到磁盘的完整代码。

以下是如何从磁盘加载模型并运行它以进行推理。

进一步阅读

如果您想深入了解,本节提供了更多关于该主题的资源。

总结

在本文中,您学习了如何将训练好的 PyTorch 模型的副本保存在磁盘上以及如何重用它。特别是,您学到了:

  • PyTorch 模型中的参数和状态是什么
  • 如何将模型的所有必要状态保存到磁盘
  • 如何从保存的状态重建一个可工作的模型

开始使用PyTorch进行深度学习!

Deep Learning with PyTorch

学习如何构建深度学习模型

...使用新发布的PyTorch 2.0库

在我的新电子书中探索如何实现
使用 PyTorch进行深度学习

它提供了包含数百个可用代码自学教程,让你从新手变成专家。它将使你掌握:
张量操作训练评估超参数优化等等...

通过动手练习开启你的深度学习之旅


查看内容

4 条评论关于 保存和加载您的 PyTorch 模型

  1. Javier 2023年3月10日 上午6:35 #

    谢谢,这真的很有用,只是一个小评论。在最后的示例代码块中,加载操作显示为使用 picklle 而不是 torch.load。Torch load 会比之前显示的保存代码更一致。

    • James Carmichael 2023年3月10日 上午7:59 #

      Javier,不客气!感谢您的反馈和建议!

  2. Poult 2023年8月10日 下午4:48 #

    您好,我认为在最后的代码窗口,31-32行
    应该是:“with open(“iris-model.pth”, “rb”) as fp
    model.load_state_dict(torch.load(fp))”

    • James Carmichael 2023年8月11日 上午8:15 #

      Poult,感谢您的反馈和建议!

发表回复

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。