使用 DataLoader 和 Dataset 训练 PyTorch 模型

当你构建和训练一个 PyTorch 深度学习模型时,你可以通过几种不同的方式提供训练数据。最终,PyTorch 模型的工作方式就像一个函数,它接收一个 PyTorch 张量并返回另一个张量。你可以自由选择如何获取输入张量。最简单的方法可能是准备一个包含整个数据集的大张量,并在每个训练步骤中从中提取一小批。但是你会发现使用 DataLoader 可以为你节省处理数据的一些代码行。

在这篇文章中,你将看到如何在 PyTorch 中使用 Data 和 DataLoader。阅读完本文后,你将学会:

  • 如何创建和使用 DataLoader 来训练你的 PyTorch 模型
  • 如何使用 Data 类即时生成数据

通过我的《用PyTorch进行深度学习》一书来启动你的项目。它提供了包含可用代码自学教程


让我们开始吧。

使用 DataLoader 和 Dataset 训练 PyTorch 模型
图片来源:Emmanuel Appiah。保留部分权利。

概述

这篇博文分为三部分;它们是:

  • 什么是 DataLoader
  • 在训练循环中使用 DataLoader

什么是 DataLoader

要训练深度学习模型,你需要数据。通常数据以数据集的形式提供。在一个数据集中,有许多数据样本或实例。你可以要求模型一次处理一个样本,但通常你会让模型一次处理一批几个样本。你可以通过使用张量上的切片语法从数据集中提取一个切片来创建一个批次。为了获得更好的训练质量,你可能还希望在每个 epoch 中打乱整个数据集,这样在整个训练循环中就没有两个批次是相同的。有时,你可能会引入**数据增强**来手动增加数据的方差。这在图像相关任务中很常见,你可以随机倾斜或缩放图像一点,从少量图像生成大量数据样本。

你可以想象要编写很多代码来完成所有这些操作。但是使用 DataLoader 会容易得多。

以下是创建 DataLoader 并从中获取一个批次的示例。在此示例中,使用了 声纳数据集,最终将其转换为 PyTorch 张量并传递给 DataLoader

从上面的输出可以看出,X_batchy_batch 都是 PyTorch 张量。loaderDataLoader 类的一个实例,它可以像迭代器一样工作。每次从它读取时,都会从原始数据集中获取一批特征和目标。

当你创建一个 DataLoader 实例时,你需要提供一个样本对列表。每个样本对是一个特征数据样本和对应的目标。需要列表是因为 DataLoader 希望使用 len() 找到数据集的总大小,并使用数组索引检索特定样本。批处理大小是 DataLoader 的一个参数,因此它知道如何从整个数据集中创建一个批次。你应该几乎总是使用 shuffle=True,这样每次加载数据时,样本都会被打乱。这对于训练很有用,因为在每个 epoch 中,你将读取每个批次一次。当你从一个 epoch 到另一个 epoch 时,由于 DataLoader 知道你已经用完了所有批次,它会重新打乱,这样你就会得到一个新的样本组合。

想开始使用PyTorch进行深度学习吗?

立即参加我的免费电子邮件速成课程(附示例代码)。

点击注册,同时获得该课程的免费PDF电子书版本。

在训练循环中使用 DataLoader

以下是在训练循环中使用 DataLoader 的示例:

你可以看到,一旦创建了 DataLoader 实例,训练循环就会变得更容易。在上面,只有训练集是用 DataLoader 打包的,因为你需要以批次的形式遍历它。你也可以为测试集创建一个 DataLoader 并用于模型评估,但由于准确性是在整个测试集上计算而不是在批次中计算,所以 DataLoader 的好处不明显。

综合所有内容,以下是完整的代码。

使用 Dataset 类创建数据迭代器

在 PyTorch 中,有一个 Dataset 类可以与 DataLoader 类紧密结合。回想一下,DataLoader 期望其第一个参数可以与 len() 和数组索引一起使用。Dataset 类是为此目的提供的基类。你可能希望使用 Dataset 类的原因是在获取数据样本之前需要进行一些特殊处理。例如,数据应该从数据库或磁盘中读取,并且你只希望在内存中保留少量样本,而不是预取所有内容。另一个例子是实时数据预处理,例如图像任务中常见的随机增强。

要使用 Dataset 类,只需从它派生并实现两个成员函数。下面是一个例子:

这并不是使用 Dataset 最强大的方式,但足以演示它是如何工作的。有了它,你可以创建一个 DataLoader 并将其用于模型训练。修改之前的示例,你将得到以下内容:

你将 dataset 设置为 SonarDataset 的一个实例,你实现了 __len__()__getitem__() 函数。这取代了前面示例中的列表来设置 DataLoader 实例。之后,训练循环中的所有内容都相同。请注意,你仍然在示例中直接为测试集使用 PyTorch 张量。

__getitem__() 函数中,你接收一个整数,该整数充当数组索引,并返回一个特征和目标对。你可以在此函数中实现任何功能:运行一些代码以生成合成数据样本,从互联网上即时读取数据,或向数据添加随机变化。当你无法将整个数据集保留在内存中时,你会发现它也很有用,因此你只需加载所需的样本。

事实上,既然你已经创建了一个 PyTorch 数据集,你就不需要使用 scikit-learn 来将数据分成训练集和测试集。在 torch.utils.data 子模块中,你有一个 random_split() 函数,它与 Dataset 类用于相同的目的。下面是一个完整的示例:

它与你之前的示例非常相似。请注意,PyTorch 模型仍然需要张量作为输入,而不是 Dataset。因此,在上面,你需要使用 default_collate() 函数将数据集中的样本收集到张量中。

进一步阅读

如果您想深入了解,本节提供了更多关于该主题的资源。

总结

在这篇文章中,你学习了如何使用 DataLoader 创建打乱的批次数据以及如何使用 Dataset 提供数据样本。具体来说,你学习了

  • DataLoader 是一种向训练循环提供批次数据的便捷方式
  • 如何使用 Dataset 产生数据样本
  • 如何结合 DatasetDataLoader 即时生成批次数据以进行模型训练

开始使用PyTorch进行深度学习!

Deep Learning with PyTorch

学习如何构建深度学习模型

...使用新发布的PyTorch 2.0库

在我的新电子书中探索如何实现
使用 PyTorch进行深度学习

它提供了包含数百个可用代码自学教程,让你从新手变成专家。它将使你掌握:
张量操作训练评估超参数优化等等...

通过动手练习开启你的深度学习之旅


查看内容

使用 DataLoader 和 Dataset 训练 PyTorch 模型 的 5 条回复

  1. Aditya 2023 年 2 月 26 日上午 12:32 #

    你好,对于预测问题,打乱数据是否合适?

    • Adrian Tam
      Adrian Tam 2023 年 3 月 15 日上午 5:43 #

      通常不会。听起来你正在谈论时间序列问题,我们不希望丢失时间顺序信息。因此,不建议打乱数据。但是你可以将时间序列转换为窗口并打乱这些窗口。希望这有帮助。

  2. Matthew Avaylon 2023 年 8 月 8 日上午 2:52 #

    在介绍中你提到 PyTorch 模型可以处理大型数据张量。这是否意味着我可以加载 MNIST 数据集,将所有训练数据设置为 x_train 和 y_train 作为张量并这样训练?

    for epoch in range(20)

    print(“epoch:” + str(epoch))

    model.train() # 将模型设置为训练模式

    y_pred = model(X_train) # 前向传播

    loss_calc = loss_func(y_pred, y_train)

    optimizer.zero_grad()

    loss_calc.backward()

    optimizer.step()

    我对 dataloader 到底做了什么有点困惑。加载器以批次的形式迭代数据并将批次输入到模型中。这是否意味着它正在将这些批次加载到内存中,其中每个批次可以被视为输入到模型中的一小部分数据张量?就像我开头说的,我们可以输入整个数据张量,而这里输入的是一个批次的张量。

  3. Peggy 2024 年 6 月 13 日下午 12:08 #

    对于多实例学习 (MIL),我的数据集包含用于训练的唯一 ID、特征和标签。

    对于预测,我需要提供包含唯一 ID 和特征,但不包含标签的数据集。

    因此,我想问一下,我是否应该修改我的数据集类,以便在预测期间处理没有标签的数据?

    谢谢!

    • James Carmichael 2024 年 6 月 14 日上午 6:48 #

      嗨,Peggy…是的,您应该修改您的数据集类,以便在多实例学习 (MIL) 的预测期间处理没有标签的数据。通常,这涉及创建一个可以管理训练(带标签)和预测(不带标签)场景的数据集类。

      这是修改数据集类的一般方法:

      ### 1. 定义数据集类
      您可以创建一个数据集类,该类接受带标签和不带标签的数据。该类应能根据标签的存在来区分它是用于训练还是预测。

      ### 2. 处理不同场景
      您可以添加一个参数来指示数据集是否包含标签。如果未提供标签,则该类应在预测期间相应地处理数据。

      ### PyTorch 示例

      以下是 PyTorch 中的一个基本示例,用于说明这一点:

      python
      import torch
      from torch.utils.data import Dataset

      class MILDataset(Dataset)
      def __init__(self, data, labels=None, mode='train')
      """
      Args
      data (list or array-like): 特征或实例列表。
      labels (list or array-like, optional): 与数据对应的标签列表。默认为 None。
      mode (str): 'train' 或 'predict',指示操作模式。默认为 'train'。
      """
      self.data = data
      self.labels = labels
      self.mode = mode

      def __len__(self)
      return len(self.data)

      def __getitem__(self, idx)
      if torch.is_tensor(idx)
      idx = idx.tolist()

      sample = self.data[idx]

      if self.mode == 'train'
      if self.labels is None
      raise ValueError("在训练模式下必须提供标签。")
      label = self.labels[idx]
      return sample, label
      elif self.mode == 'predict'
      return sample
      else
      raise ValueError("模式应为 'train' 或 'predict'。")

      # 示例用法
      # 用于训练
      train_data = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
      train_labels = [0, 1, 0]
      train_dataset = MILDataset(data=train_data, labels=train_labels, mode='train')

      # 用于预测
      predict_data = [[0.7, 0.8], [0.9, 1.0]]
      predict_dataset = MILDataset(data=predict_data, mode='predict')

      ### 解释
      – **初始化 (__init__ 方法)**
      data 参数保存特征。
      labels 参数是可选的,仅在训练模式下需要。
      mode 参数指定数据集是用于训练还是预测。

      – **长度 (__len__ 方法)**
      – 返回数据集中实例的数量。

      – **获取项 (__getitem__ 方法)**
      – 如果处于训练模式 ('train'),它返回一个 (样本, 标签) 元组。
      – 如果处于预测模式 ('predict'),它只返回样本(特征向量)。

      ### 将数据集与 DataLoader 一起使用
      您可以将此数据集类与 PyTorch 的 DataLoader 一起用于训练和预测:

      python
      from torch.utils.data import DataLoader

      # 用于训练
      train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

      # 用于预测
      predict_loader = DataLoader(predict_dataset, batch_size=2, shuffle=False)

      这种结构使您的数据集类具有灵活性,并能有效处理训练(带标签)和预测(不带标签)场景。

发表回复

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。