可视化 PyTorch 模型

PyTorch 是一个深度学习库。您可以使用 PyTorch 构建非常复杂的深度学习模型。然而,有时您需要模型的图形化架构表示。在这篇文章中,您将学习

  • 如何以交换格式保存您的 PyTorch 模型
  • 如何使用 Netron 创建图形化表示。

通过我的《用PyTorch进行深度学习》一书来启动你的项目。它提供了包含可用代码自学教程


让我们开始吧。

可视化 PyTorch 模型
图片由 Ken Cheung 提供。保留部分权利。

概述

这篇文章分为两部分:

  • 为什么 PyTorch 模型的图形化表示很困难
  • 如何使用 Netron 创建模型图

为什么 PyTorch 模型的图形化表示很困难

PyTorch 是一个非常灵活的深度学习库。严格来说,它从不强制您应该如何构建模型,只要它像一个函数一样,可以将输入张量转换为输出张量即可。这是一个问题:对于一个模型,您永远不知道它是如何工作的,除非您跟踪输入张量并收集轨迹,直到获得输出张量。因此,将 PyTorch 模型转换为图片并非易事。

有多个库可以解决这个问题。但总的来说,只有两种方法可以解决:您可以跟踪前向传播中的张量,查看应用了哪些操作(即层),或者跟踪反向传播中的张量,查看梯度如何传播到输入。您只能通过这种方式找到模型内部结构的线索。

想开始使用PyTorch进行深度学习吗?

立即参加我的免费电子邮件速成课程(附示例代码)。

点击注册,同时获得该课程的免费PDF电子书版本。

如何使用 Netron 创建模型图

当您保存 PyTorch 模型时,您保存的是它的状态。您可以使用 model.state_dict() 获取模型状态。虽然权重张量有名称,这有助于您将它们恢复到模型中,但您没有关于权重如何相互连接的线索。连接张量并找出它们之间关系的唯一方法是获取张量梯度:当您运行模型并获得输出时,所涉及的计算,包括对其他张量的依赖,都会被每个中间张量记住,以便可以执行自动微分。

事实上,如果您想了解 PyTorch 模型背后的算法,这也是可行的方法。从 PyTorch 模型创建图形的工具很少。下面,您将了解 Netron 工具。它是一个“深度学习模型查看器”。它是一个可以安装并在 macOS、Linux 和 Windows 上运行的软件。您可以访问下面的页面并下载适用于您平台的软件

还有一个在线版本可用,您可以通过上传模型文件来查看您的模型。

Netron 无法从保存的状态可视化 PyTorch 模型,因为没有足够的线索来解释模型的结构。然而,PyTorch 允许您将模型转换为 Netron 可以理解的交换格式 ONNX。

让我们从一个例子开始。下面您创建了一个简单的模型来对鸢尾花数据集进行分类。这是一个有三个类别的分类问题。因此,模型应该输出一个包含三个值的向量。解决此问题的完整代码如下,数据集来自 scikit-learn

运行以上代码会产生以下输出,例如

因此您知道 model 是一个可以接受张量并返回张量的 PyTorch 模型。您可以使用 torch.onnx.export() 函数将此模型转换为 ONNX 格式

运行此命令将在本地目录中创建一个名为 iris.onnx 的文件。您需要提供一个与模型兼容的示例张量作为输入(上例中的 X_test)。这是因为在转换过程中,它需要遵循此示例张量来理解应该应用哪些操作,以便您可以将算法一步步转换为 ONNX 格式。PyTorch 模型中的每个权重都是一个张量,并分配有一个名称。但输入和输出张量通常没有命名,因此您在运行 export() 时需要为它们提供名称。这些名称应作为字符串列表提供,因为通常情况下,一个模型可以接受多个张量并返回多个张量。

通常您应该在训练循环之后运行 export()。这是因为创建的 ONNX 模型包含一个完整的模型,您可以在没有 PyTorch 库的情况下运行它。您希望将优化后的权重保存到其中。然而,为了在 Netron 中可视化模型,模型的质量不是问题。您可以在 PyTorch 模型创建后立即运行 export()

启动 Netron 后,您可以打开保存的 ONNX 文件。在此示例中,您应该看到以下屏幕

它展示了输入张量如何通过深度学习模型中的不同操作连接到输出张量。您提供给 export() 函数的输入和输出张量名称用于可视化。单击一个框将为您提供有关特定张量或操作的更多详细信息。然而,您在 Netron 中看到的操作名称可能与您在 PyTorch 中称呼它们的方式不同。在上面的屏幕中,您看到 nn.Linear() 层变为“Gemm”,它代表“通用矩阵乘法”操作。您甚至可以通过几次点击在 Netron 中检查图层上的权重。

如果您想保留此可视化的副本,可以在 Netron 中将其导出为 PNG 格式。

进一步阅读

Netron 是一个开源项目,您可以在 Github 上找到其源代码

Netron 的在线版本可在以下链接找到

另一个可视化库是 torchviz,但与您上面看到的示例不同,它从反向传播中追踪模型

总结

在这篇文章中,您学习了如何可视化模型。特别是,您学习了

  • 为什么可视化 PyTorch 模型很困难
  • 如何将 PyTorch 模型转换为 ONNX 格式
  • 如何使用 Netron 可视化 ONNX 模型

开始使用PyTorch进行深度学习!

Deep Learning with PyTorch

学习如何构建深度学习模型

...使用新发布的PyTorch 2.0库

在我的新电子书中探索如何实现
使用 PyTorch进行深度学习

它提供了包含数百个可用代码自学教程,让你从新手变成专家。它将使你掌握:
张量操作训练评估超参数优化等等...

通过动手练习开启你的深度学习之旅


查看内容

可视化 PyTorch 模型 的 2 条回复

  1. Trương Quốc Quân 2023 年 3 月 10 日 下午 1:50 #

    太棒了!这适用于 TensorFlow 和 Keras 吗?我一直在使用 StellarGraph,一个基于 TensorFlow 构建的图神经网络库。Tensorflow 也支持将模型导出到 ONNX。

发表评论

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。