tensorflow.data API 简介

当你构建和训练 Keras 深度学习模型时,你可以通过多种方式提供训练数据。将数据表示为 NumPy 数组或 TensorFlow 张量是很常见的。另一种方式是创建一个 Python 生成器函数,让训练循环从中读取数据。还有一种提供数据的方式是使用 tf.data 数据集。

在本教程中,你将看到如何为 Keras 模型使用 tf.data 数据集。完成本教程后,你将学会:

  • 如何创建和使用 tf.data 数据集
  • 与生成器函数相比,这样做的优点

让我们开始吧。

TensorFlow.data API 入门指南
图片来源:Monika MG。部分权利保留。

概述

本文分为四个部分,它们是:

  • 使用 NumPy 数组和生成器函数训练 Keras 模型
  • 使用 tf.data 创建数据集
  • 从生成器函数创建数据集
  • 带预取的数据集

使用 NumPy 数组和生成器函数训练 Keras 模型

在了解 tf.data API 如何工作之前,让我们回顾一下通常如何训练 Keras 模型。

首先,你需要一个数据集。一个例子是 Keras API 自带的 fashion MNIST 数据集。该数据集包含 60,000 个训练样本和 10,000 个测试样本,均为 28×28 像素的灰度图,对应的分类标签用整数 0 到 9 编码。

数据集是一个 NumPy 数组。然后你可以构建一个用于分类的 Keras 模型,并使用模型的 fit() 函数将 NumPy 数组作为数据提供。

完整代码如下:

运行此代码将打印以下内容:

此外,还会创建以下图表,显示模型训练 50 个 epoch 时的验证准确率:

用同样网络训练的另一种方式是提供来自 Python 生成器函数的数据,而不是 NumPy 数组。生成器函数是指包含 yield 语句的函数,它在函数运行时向数据使用者发出数据。fashion MNIST 数据集的生成器可以如下创建:

此函数应使用语法 batch_generator(train_image, train_label, 32) 调用。它将无限地分批扫描输入数组。一旦到达数组末尾,它将从头开始重新开始。

使用生成器训练 Keras 模型与使用 fit() 函数类似。

您不需要提供数据和标签,只需提供生成器即可,因为它会同时输出两者。当数据以 NumPy 数组形式提供时,您可以通过查看数组的长度来知道样本的数量。Keras 在使用完整个数据集一次时就可以完成一个 epoch。但是,您的生成器函数会无限地发出批次,因此您需要使用 fit() 函数的 steps_per_epoch 参数来告诉它何时结束一个 epoch。

在上面的代码中,验证数据是作为 NumPy 数组提供的,但您也可以改用生成器,并指定 validation_steps 参数。

以下是使用生成器函数的完整代码,其输出与前面的示例相同:

使用 tf.data 创建数据集

给定 fashion MNIST 数据已加载,您可以将其转换为 tf.data 数据集,如下所示:

这将按如下方式打印数据集的规范:

您可以看到数据是一个元组(因为元组是传递给 from_tensor_slices() 函数的参数),其中第一个元素形状为 (28,28),而第二个元素是标量。两个元素都存储为 8 位无符号整数。

如果您在创建数据集时不将数据作为两个 NumPy 数组的元组提供,您也可以稍后进行。下面的代码创建了相同的数据集,但首先为图像数据和标签分别创建数据集,然后再将它们组合起来:

这将打印相同的规范。

数据集中的 zip() 函数就像 Python 中的 zip() 函数一样,因为它将多个数据集中的数据逐一匹配成一个元组。

使用 tf.data 数据集的一个好处是处理数据的灵活性。以下是如何使用数据集训练 Keras 模型(其中批次大小设置为数据集):

这是使用数据集最简单的用例。如果您深入研究,您会发现数据集只是一个迭代器。因此,您可以使用以下方法打印数据集中的每个样本:

数据集内置了许多函数。前面使用的 batch() 就是其中之一。如果您从数据集中创建批次并打印它们,您会得到以下结果:

在这里,批次中的每个项不是一个样本,而是样本的一个批次。您还可以使用 map()filter()reduce() 等函数进行序列转换,或者使用 concatendate()interleave() 与其他数据集进行组合。还有 repeat()take()take_while()skip() 等函数,就像 Python 的 itertools 模块中我们熟悉的功能一样。函数列表的完整内容可以在 API 文档中找到。

从生成器函数创建数据集

到目前为止,您已经看到了如何在 Keras 模型训练中将数据集用于 NumPy 数组。实际上,数据集也可以从生成器函数创建。但是,与上面示例中生成一个**批次**的生成器函数不同,现在您创建一个生成一个样本的生成器函数。以下是该函数:

此函数通过随机化索引向量来随机化输入数组。然后它一次生成一个样本。与前面的示例不同,此生成器将在数组中的样本用尽时结束。

您可以使用 from_generator() 从该函数创建数据集。您需要提供生成器函数的名称(而不是实例化的生成器),还需要提供数据集的输出签名。这是必需的,因为 tf.data.Dataset API 在生成器被消耗之前无法推断数据集规范。

运行上述代码将打印与之前相同的规范。

这样的数据集在功能上与您之前创建的数据集等效。因此,您可以像以前一样将其用于训练。以下是完整代码:

带预取的数据集

使用数据集的真正好处是使用 prefetch()

使用 NumPy 数组进行训练可能在性能上是最好的。但是,这意味着您需要将所有数据加载到内存中。使用生成器函数进行训练允许您一次准备一个批次,其中数据可以按需从磁盘加载,例如。但是,使用生成器函数训练 Keras 模型意味着在任何时候运行的都是训练循环或生成器函数。让生成器函数和 Keras 的训练循环并行运行并不容易。

数据集是允许生成器和训练循环并行运行的 API。如果您的生成器计算成本很高(例如,实时执行图像增强),您可以从这样的生成器函数创建数据集,然后将其与 prefetch() 一起使用,如下所示:

prefetch() 的数字参数是缓冲区大小。在这里,数据集被要求将三个批次保存在内存中,供训练循环使用。每当消耗一个批次时,数据集 API 将恢复生成器函数,在后台异步地重新填充缓冲区。因此,您可以让训练循环和生成器函数中的数据准备算法并行运行。

值得一提的是,在上一节中,您为数据集 API 创建了一个随机化生成器。实际上,数据集 API 也有一个 shuffle() 函数来执行相同的操作,但除非数据集足够小以适合内存,否则您可能不想使用它。

shuffle() 函数与 prefetch() 一样,接受一个缓冲区大小参数。随机化算法会用数据集填充缓冲区,并从中随机抽取一个元素。消耗的元素将被数据集中的下一个元素替换。因此,您需要一个与数据集本身一样大的缓冲区才能进行真正的随机打乱。以下代码段演示了此限制:

上述输出如下所示:

在这里,您可以看到数字在附近被随机打乱,并且您永远不会在输出中看到大数字。

进一步阅读

有关 tf.data 数据集的更多信息,请参阅其 API 文档。

总结

在这篇文章中,您已经了解了如何使用 tf.data 数据集以及它如何在 Keras 模型训练中使用。

具体来说,你学到了:

  • 如何使用 NumPy 数组、生成器和数据集中的数据训练模型
  • 如何使用 NumPy 数组或生成器函数创建数据集
  • 如何将预取与数据集结合使用,使生成器和训练循环并行运行

4 条对 TensorFlow.data API 入门指南 的回复

  1. Lukas 2022年8月18日 凌晨4:33 #

    抱歉,没看到 Adrian 写了这篇教程。
    感谢 Adrian,我的朋友!

    • James Carmichael 2022年8月18日 上午10:54 #

      Lukas,反馈很棒!非常感谢!

  2. david 2022年8月21日 凌晨1:27 #

    嗨,Jason博士
    感谢您发表所有精彩的帖子。我真的很喜欢它们。

    • James Carmichael 2022年8月21日 上午7:49 #

      David,谢谢您的反馈和支持!我们非常感激!

留下回复

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。