使用检查工具开发Python程序

Python 是一种解释型语言。这意味着有一个解释器来运行我们的程序,而不是编译代码并本地运行。在 Python 中,REPL(读取-评估-打印循环)可以逐行运行命令。结合 Python 提供的一些检查工具,有助于开发代码。

接下来,您将看到如何利用 Python 解释器来检查对象和开发程序。

完成本教程后,你将学到:

  • 如何在 Python 解释器中工作
  • 如何在 Python 中使用检查函数
  • 如何在检查函数的帮助下逐步开发解决方案

通过我的新书 Python for Machine Learning 启动您的项目,其中包括分步教程以及所有示例的Python 源代码文件。


让我们开始吧!

使用检查工具开发 Python 程序。
图片由 Tekton 拍摄。部分权利保留。

教程概述

本教程分为四个部分;它们是:

  • PyTorch 和 TensorFlow
  • 寻找线索
  • 从权重中学习
  • 制作一个复制器

PyTorch 和 TensorFlow

PyTorch 和 TensorFlow 是 Python 中两个最大的神经网络库。它们的代码不同,但它们能做的事情是相似的。

考虑经典的 MNIST 手写数字识别问题,您可以构建一个 LeNet-5 模型来分类数字,如下所示:

这是一个不需要任何验证或测试的简化代码。 TensorFlow 中的对应代码如下:

运行此程序将从 PyTorch 代码中获得文件 lenet5.pt,从 TensorFlow 代码中获得文件 lenet5.h5

寻找线索

如果您理解上述神经网络的作用,您应该可以指出,除了每个层中的许多乘法和加法计算之外,别无他物。在数学上,输入和每个全连接层的之间有一个矩阵乘法,然后在结果中加上偏置。在卷积层中,将核与输入矩阵的一部分进行逐元素乘法,然后将结果相加,并将偏置添加为一个特征映射的输出元素。

想开始学习机器学习 Python 吗?

立即参加我为期7天的免费电子邮件速成课程(附示例代码)。

点击注册,同时获得该课程的免费PDF电子书版本。

在使用两种不同的框架开发相同的 LeNet-5 模型时,如果它们的权重相同,就应该能够使它们工作得完全一样。鉴于它们的架构相同,如何将一个模型的权重复制到另一个模型?

您可以按如下方式加载已保存的模型:

这可能不会告诉您太多信息。但是,如果您在命令行中运行 python 而不带任何参数,您将启动 REPL,在其中您可以输入上述代码(您可以使用 quit() 退出 REPL)。

以上不会打印任何内容。但是您可以使用内置的 type() 命令检查加载的两个模型。

因此,您在这里知道它们分别是 PyTorch 和 Keras 的神经网络模型。由于它们是经过训练的模型,权重必须存储在其中。那么,如何才能在这些模型中找到权重呢?由于它们是对象,最简单的方法是使用内置函数 dir() 来检查它们的成员。

每个对象都有很多成员。有些是属性,有些是类的成员方法。按照惯例,以下划线开头的成员是内部成员,在正常情况下不应访问。如果想查看每个成员的更多信息,可以使用 inspect 模块中的 getmembers() 函数。

getmembers() 函数的输出是一个元组列表,其中每个元组是成员的名称和成员本身。例如,从上面可以看出,__call__ 是一个“绑定方法”,即类的成员方法。

仔细查看成员名称,您会发现 PyTorch 模型中的“state”应该是您感兴趣的,而在 Keras 模型中,您会发现一个名为“weights”的成员。为了精简它们的名称,您可以在解释器中执行以下操作:

这可能需要一些试验和错误。但这并不太难,您可能会发现可以使用 PyTorch 模型中的 state_dict 查看权重。

对于 TensorFlow/Keras 模型,您可以使用 get_weights() 找到权重。

这里也有 weights 属性。

在这里,您可以观察到以下几点:在 PyTorch 模型中,state_dict() 函数会给出一个 OrderedDict,它是一个具有指定顺序键的字典。有诸如 0.weight 之类的键,它们映射到张量值。在 Keras 模型中,get_weights() 函数返回一个列表。列表中的每个元素都是一个 NumPy 数组。weight 属性也保存一个列表,但元素是 tf.Variable 类型。

您可以通过检查每个张量或数组的形状来了解更多信息。

虽然您在上面的 Keras 模型中没有看到层名称,但实际上,您可以使用类似的推理来查找层并获取它们的名称。

从权重中学习

通过比较 PyTorch 模型 state_dict() 的结果和 Keras 模型 get_weights() 的结果,您可以看到它们都包含 10 个元素。从 PyTorch 张量和 NumPy 数组的形状来看,您还可以注意到它们具有相似的形状。这可能是因为两个框架都按照从输入到输出的顺序识别模型。您还可以通过比较 state_dict() 输出的键与 Keras 模型的层名称来进一步确认。

您可以检查如何通过提取一个 PyTorch 张量并检查它来操作它。

从 PyTorch 张量上 dir() 的输出中,您找到了一个名为 numpy 的成员,通过调用该函数,它似乎可以将张量转换为 NumPy 数组。您对此可以相当确定,因为您看到数字匹配并且形状也匹配。事实上,通过查看文档,您可以更加确定。

help() 函数将显示函数的文档字符串,这通常是它的文档。

由于这是第一个卷积层的核,通过比较该核的形状与 Keras 模型的形状,您可能会注意到它们的形状不同。

请注意,第一层的输入是一个 28×28×1 的图像数组,而输出是 6 个特征图。将核形状中的 1 和 6 分别对应输入和输出的通道数是很自然的。同样,根据我们对卷积层机制的理解,核应该是 5×5 的矩阵。

此时,您可能已经猜到,在 PyTorch 卷积层中,核表示为(输出 × 输入 × 高度 × 宽度),而在 Keras 中,它表示为(高度 × 宽度 × 输入 × 输出)。

同样,您还可以看到在全连接层中,PyTorch 将核表示为(输出 × 输入),而 Keras 是(输入 × 输出)。

将权重和张量进行匹配并将它们的形状并排放置应该能使这些更加清晰。

我们还可以匹配 Keras 权重的名称和 PyTorch 张量的名称。

制作一个复制器

既然您已经了解了每个模型中的权重是什么样的,那么创建一个程序来将权重从一个模型复制到另一个模型似乎并不困难。关键是回答:

  1. 如何在每个模型中设置权重
  2. 权重在每个模型中应该是什么样的(形状和数据类型)

第一个问题可以通过之前使用内置函数 dir() 进行的检查来回答。您看到了 PyTorch 模型中的 load_state_dict 成员,它似乎就是该工具。类似地,在 Keras 模型中,您看到了一个名为 set_weight 的成员,它的名称正是 get_weight 的对应名称。您可以通过在线检查它们的文档或通过 help() 函数来进一步确认。

您已经确认这些都是函数,并且它们的文档解释了它们是您原先认为的那样。从文档中,您进一步了解到 PyTorch 模型的 load_state_dict() 函数期望的参数格式与 state_dict() 函数返回的格式相同;Keras 模型的 set_weights() 函数期望的格式与 get_weights() 函数返回的格式相同。

现在您已经完成了 Python REPL 的探索之旅(您可以输入 quit() 退出)。

通过研究如何重塑权重以及如何从一种数据类型转换到另一种数据类型,你会得到以下程序

反之,将 PyTorch 模型中的权重复制到 Keras 模型中,操作也类似,

然后,你可以通过将随机数组作为输入来验证它们是否工作相同,此时你预期输出结果会完全一致。

在我们的例子中,输出是

这些结果在足够的精度下是吻合的。请注意,由于训练的随机性,你的结果可能不完全相同。另外,由于浮点计算的特性,即使权重相同,PyTorch 和 TensorFlow/Keras 模型也不会产生完全相同的输出。

然而,这里的目标是展示你如何利用 Python 的内省工具来理解你不知道的东西并开发解决方案。

进一步阅读

如果您想深入了解,本节提供了更多关于该主题的资源。

文章

总结

在本教程中,你学习了如何在 Python REPL 下进行操作,并使用内省函数来开发解决方案。具体来说,

  • 你学习了如何在 REPL 中使用内省函数来了解对象的内部成员。
  • 你学习了如何使用 REPL 来试验 Python 代码。
  • 因此,你开发了一个在 PyTorch 和 Keras 模型之间进行转换的程序。

掌握机器学习 Python!

Python For Machine Learning

更自信地用 Python 编写代码

...从学习实用的 Python 技巧开始

在我的新电子书中探索如何实现
用于机器学习的 Python

它提供自学教程数百个可运行的代码,为您提供包括以下技能:
调试性能分析鸭子类型装饰器部署等等...

向您展示高级 Python 工具箱,用于
您的项目


查看内容

暂无评论。

留下回复

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。