在 PyTorch 中使用数据集类

在机器学习和深度学习问题中,数据准备工作非常重要。数据通常很杂乱,在使用模型训练之前需要进行预处理。如果数据准备不正确,模型将无法很好地泛化。
数据预处理的一些常见步骤包括:

  • 数据归一化:包括将数据集中的数据归一化到一定的数值范围内。
  • 数据增强:通过向现有样本添加噪声或偏移特征来生成新样本,使其更加多样化。

数据准备是任何机器学习管道中的关键步骤。PyTorch 提供了许多模块,例如 torchvision,它提供了数据集和数据集类,使数据准备变得容易。

在本教程中,我们将演示如何在 PyTorch 中处理数据集和转换,以便您可以创建自己的自定义数据集类并根据需要操作数据集。具体来说,您将学到:

  • 如何创建简单的数据集类并对其应用转换。
  • 如何构建可调用转换并将其应用于数据集对象。
  • 如何对数据集对象组合各种转换。

请注意,在这里您将使用简单的数据集来理解概念,在下一部分教程中,您将有机会处理图像的数据集对象。

通过我的《用PyTorch进行深度学习》一书来启动你的项目。它提供了包含可用代码自学教程


让我们开始吧。

在 PyTorch 中使用数据集类
图片来源:NASA。部分权利保留。

概述

本教程分为三个部分;它们是

  • 创建简单的数据集类
  • 创建可调用转换
  • 为数据集组合多个转换

创建简单的数据集类

开始之前,我们需要导入一些包才能创建数据集类。

我们将从 `torch.utils.data` 导入抽象类 `Dataset`。因此,我们在数据集类中重写以下方法:

  • __len__,以便 `len(dataset)` 可以告诉我们数据集的大小。
  • __getitem__,通过支持索引操作来访问数据集中的数据样本。例如,可以使用 `dataset[i]` 来检索第 i 个数据样本。

同样,`torch.manual_seed()` 会强制随机函数在每次重新编译时产生相同的数字。

现在,让我们定义数据集类。

在对象构造函数中,我们创建了特征和目标的值,即 `x` 和 `y`,并将它们的值赋给了张量 `self.x` 和 `self.y`。每个张量包含 20 个数据样本,而 `data_length` 属性存储数据样本的数量。稍后我们将讨论转换。

`SimpleDataset` 对象的行为与任何 Python 可迭代对象(如列表或元组)一样。现在,让我们创建 `SimpleDataset` 对象并查看其总长度和索引 1 处的值。

输出如下:

由于我们的数据集是可迭代的,让我们使用循环打印出前四个元素。

输出如下:

创建可调用转换

在许多情况下,您需要创建可调用的转换来对数据进行归一化或标准化。然后可以将这些转换应用于张量。让我们创建一个可调用的转换,并将其应用于本教程前面创建的“简单数据集”对象。

我们创建了一个简单的自定义转换 `MultDivide`,它将 `x` 乘以 2,将 `y` 除以 3。这并非用于实际用途,而是为了演示可调用类如何作为我们数据集类的转换。请记住,我们在 `simple_dataset` 中声明了一个 `transform = None` 参数。现在,我们可以用我们刚刚创建的自定义转换对象替换该 `None`。

因此,让我们演示一下如何做到这一点,并将此转换对象应用于我们的数据集,以查看它如何转换我们数据集中前四个元素。

输出如下:

正如您所见,转换已成功应用于数据集的前四个元素。

想开始使用PyTorch进行深度学习吗?

立即参加我的免费电子邮件速成课程(附示例代码)。

点击注册,同时获得该课程的免费PDF电子书版本。

为数据集组合多个转换

我们通常希望对数据集执行多个串联的转换。这可以通过从 torchvision 的 transforms 模块导入 `Compose` 类来实现。例如,假设我们构建另一个转换 `SubtractOne`,并将其应用于我们的数据集,同时还应用我们之前创建的 `MultDivide` 转换。

应用后,新创建的转换将从数据集的每个元素中减去 1。

如前所述,现在我们将使用 `Compose` 方法将这两个转换结合起来。

请注意,首先会将 `MultDivide` 转换应用于数据集,然后会将 `SubtractOne` 转换应用于数据集的转换后的元素。
我们将 `Compose` 对象(其中包含 `MultDivide()` 和 `SubtractOne()` 的组合)传递给我们的 `SimpleDataset` 对象。

现在已经将多个转换的组合应用于数据集,让我们打印出我们转换后的数据集的前四个元素。

总而言之,完整的代码如下:

总结

在本教程中,您学习了如何在 PyTorch 中创建自定义数据集和转换。特别是,您学到了:

  • 如何创建简单的数据集类并对其应用转换。
  • 如何构建可调用转换并将其应用于数据集对象。
  • 如何对数据集对象组合各种转换。

开始使用PyTorch进行深度学习!

Deep Learning with PyTorch

学习如何构建深度学习模型

...使用新发布的PyTorch 2.0库

在我的新电子书中探索如何实现
使用 PyTorch进行深度学习

它提供了包含数百个可用代码自学教程,让你从新手变成专家。它将使你掌握:
张量操作训练评估超参数优化等等...

通过动手练习开启你的深度学习之旅


查看内容

暂无评论。

留下回复

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。