在机器学习和深度学习问题中,数据准备工作非常重要。数据通常很杂乱,在使用模型训练之前需要进行预处理。如果数据准备不正确,模型将无法很好地泛化。
数据预处理的一些常见步骤包括:
- 数据归一化:包括将数据集中的数据归一化到一定的数值范围内。
- 数据增强:通过向现有样本添加噪声或偏移特征来生成新样本,使其更加多样化。
数据准备是任何机器学习管道中的关键步骤。PyTorch 提供了许多模块,例如 torchvision,它提供了数据集和数据集类,使数据准备变得容易。
在本教程中,我们将演示如何在 PyTorch 中处理数据集和转换,以便您可以创建自己的自定义数据集类并根据需要操作数据集。具体来说,您将学到:
- 如何创建简单的数据集类并对其应用转换。
- 如何构建可调用转换并将其应用于数据集对象。
- 如何对数据集对象组合各种转换。
请注意,在这里您将使用简单的数据集来理解概念,在下一部分教程中,您将有机会处理图像的数据集对象。
通过我的《用PyTorch进行深度学习》一书来启动你的项目。它提供了包含可用代码的自学教程。
让我们开始吧。

在 PyTorch 中使用数据集类
图片来源:NASA。部分权利保留。
概述
本教程分为三个部分;它们是
- 创建简单的数据集类
- 创建可调用转换
- 为数据集组合多个转换
创建简单的数据集类
开始之前,我们需要导入一些包才能创建数据集类。
1 2 3 |
import torch from torch.utils.data import Dataset torch.manual_seed(42) |
我们将从 `torch.utils.data` 导入抽象类 `Dataset`。因此,我们在数据集类中重写以下方法:
__len__
,以便 `len(dataset)` 可以告诉我们数据集的大小。__getitem__
,通过支持索引操作来访问数据集中的数据样本。例如,可以使用 `dataset[i]` 来检索第 i 个数据样本。
同样,`torch.manual_seed()` 会强制随机函数在每次重新编译时产生相同的数字。
现在,让我们定义数据集类。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
class SimpleDataset(Dataset): # 定义构造函数中的值 def __init__(self, data_length = 20, transform = None): self.x = 3 * torch.eye(data_length, 2) self.y = torch.eye(data_length, 4) self.transform = transform self.len = data_length # 获取数据样本 def __getitem__(self, idx): sample = self.x[idx], self.y[idx] if self.transform: sample = self.transform(sample) return sample # 获取数据大小/长度 def __len__(self): return self.len |
在对象构造函数中,我们创建了特征和目标的值,即 `x` 和 `y`,并将它们的值赋给了张量 `self.x` 和 `self.y`。每个张量包含 20 个数据样本,而 `data_length` 属性存储数据样本的数量。稍后我们将讨论转换。
`SimpleDataset` 对象的行为与任何 Python 可迭代对象(如列表或元组)一样。现在,让我们创建 `SimpleDataset` 对象并查看其总长度和索引 1 处的值。
1 2 3 |
dataset = SimpleDataset() print("simple_dataset 对象长度: ", len(dataset)) print("访问 simple_dataset 对象索引 1 处的值: ", dataset[1]) |
输出如下:
1 2 |
simple_dataset 对象长度: 20 访问 simple_dataset 对象索引 1 处的值: (tensor([0., 3.]), tensor([0., 1., 0., 0.])) |
由于我们的数据集是可迭代的,让我们使用循环打印出前四个元素。
1 2 3 |
for i in range(4): x, y = dataset[i] print(x, y) |
输出如下:
1 2 3 4 |
tensor([3., 0.]) tensor([1., 0., 0., 0.]) tensor([0., 3.]) tensor([0., 1., 0., 0.]) tensor([0., 0.]) tensor([0., 0., 1., 0.]) tensor([0., 0.]) tensor([0., 0., 0., 1.]) |
创建可调用转换
在许多情况下,您需要创建可调用的转换来对数据进行归一化或标准化。然后可以将这些转换应用于张量。让我们创建一个可调用的转换,并将其应用于本教程前面创建的“简单数据集”对象。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
# 创建可调用转换类 mult_divide class MultDivide: # 构造函数 def __init__(self, mult_x = 2, divide_y = 3): self.mult_x = mult_x self.divide_y = divide_y # 调用器 def __call__(self, sample): x = sample[0] y = sample[1] x = x * self.mult_x y = y / self.divide_y sample = x, y return sample |
我们创建了一个简单的自定义转换 `MultDivide`,它将 `x` 乘以 2,将 `y` 除以 3。这并非用于实际用途,而是为了演示可调用类如何作为我们数据集类的转换。请记住,我们在 `simple_dataset` 中声明了一个 `transform = None` 参数。现在,我们可以用我们刚刚创建的自定义转换对象替换该 `None`。
因此,让我们演示一下如何做到这一点,并将此转换对象应用于我们的数据集,以查看它如何转换我们数据集中前四个元素。
1 2 3 4 5 6 7 8 9 |
# 调用转换对象 mul_div = MultDivide() custom_dataset = SimpleDataset(transform = mul_div) for i in range(4): x, y = dataset[i] print('索引: ', i, '原始_x: ', x, '原始_y: ', y) x_, y_ = custom_dataset[i] print('索引: ', i, '转换后的_x:', x_, '转换后的_y:', y_) |
输出如下:
1 2 3 4 5 6 7 8 |
索引: 0 原始_x: tensor([3., 0.]) 原始_y: tensor([1., 0., 0., 0.]) 索引: 0 转换后的_x: tensor([6., 0.]) 转换后的_y: tensor([0.3333, 0.0000, 0.0000, 0.0000]) 索引: 1 原始_x: tensor([0., 3.]) 原始_y: tensor([0., 1., 0., 0.]) 索引: 1 转换后的_x: tensor([0., 6.]) 转换后的_y: tensor([0.0000, 0.3333, 0.0000, 0.0000]) 索引: 2 原始_x: tensor([0., 0.]) 原始_y: tensor([0., 0., 1., 0.]) 索引: 2 转换后的_x: tensor([0., 0.]) 转换后的_y: tensor([0.0000, 0.0000, 0.3333, 0.0000]) 索引: 3 原始_x: tensor([0., 0.]) 原始_y: tensor([0., 0., 0., 1.]) 索引: 3 转换后的_x: tensor([0., 0.]) 转换后的_y: tensor([0.0000, 0.0000, 0.0000, 0.3333]) |
正如您所见,转换已成功应用于数据集的前四个元素。
想开始使用PyTorch进行深度学习吗?
立即参加我的免费电子邮件速成课程(附示例代码)。
点击注册,同时获得该课程的免费PDF电子书版本。
为数据集组合多个转换
我们通常希望对数据集执行多个串联的转换。这可以通过从 torchvision 的 transforms 模块导入 `Compose` 类来实现。例如,假设我们构建另一个转换 `SubtractOne`,并将其应用于我们的数据集,同时还应用我们之前创建的 `MultDivide` 转换。
应用后,新创建的转换将从数据集的每个元素中减去 1。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
from torchvision import transforms # 创建 subtract_one 转换 class SubtractOne: # 构造函数 def __init__(self, number = 1): self.number = number # 调用器 def __call__(self, sample): x = sample[0] y = sample[1] x = x - self.number y = y - self.number sample = x, y return sample |
如前所述,现在我们将使用 `Compose` 方法将这两个转换结合起来。
1 2 |
# 组合多个转换 mult_transforms = transforms.Compose([MultDivide(), SubtractOne()]) |
请注意,首先会将 `MultDivide` 转换应用于数据集,然后会将 `SubtractOne` 转换应用于数据集的转换后的元素。
我们将 `Compose` 对象(其中包含 `MultDivide()` 和 `SubtractOne()` 的组合)传递给我们的 `SimpleDataset` 对象。
1 2 |
# 创建一个带有多个转换的新 simple_dataset 对象 new_dataset = SimpleDataset(transform = mult_transforms) |
现在已经将多个转换的组合应用于数据集,让我们打印出我们转换后的数据集的前四个元素。
1 2 3 4 5 |
for i in range(4): x, y = dataset[i] print('索引: ', i, '原始_x: ', x, '原始_y: ', y) x_, y_ = new_dataset[i] print('索引: ', i, '转换后的 x_:', x_, '转换后的 y_:', y_) |
总而言之,完整的代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import torch from torch.utils.data import Dataset from torchvision import transforms torch.manual_seed(2) class SimpleDataset(Dataset): # 定义构造函数中的值 def __init__(self, data_length = 20, transform = None): self.x = 3 * torch.eye(data_length, 2) self.y = torch.eye(data_length, 4) self.transform = transform self.len = data_length # 获取数据样本 def __getitem__(self, idx): sample = self.x[idx], self.y[idx] if self.transform: sample = self.transform(sample) return sample # 获取数据大小/长度 def __len__(self): return self.len # 创建可调用转换类 mult_divide class MultDivide: # 构造函数 def __init__(self, mult_x = 2, divide_y = 3): self.mult_x = mult_x self.divide_y = divide_y # 调用器 def __call__(self, sample): x = sample[0] y = sample[1] x = x * self.mult_x y = y / self.divide_y sample = x, y return sample # 创建 subtract_one 转换 class SubtractOne: # 构造函数 def __init__(self, number = 1): self.number = number # 调用器 def __call__(self, sample): x = sample[0] y = sample[1] x = x - self.number y = y - self.number sample = x, y return sample # 组合多个转换 mult_transforms = transforms.Compose([MultDivide(), SubtractOne()]) # 创建一个带有多个转换的新 simple_dataset 对象 dataset = SimpleDataset() new_dataset = SimpleDataset(transform = mult_transforms) print("simple_dataset 对象长度: ", len(dataset)) print("访问 simple_dataset 对象索引 1 处的值: ", dataset[1]) for i in range(4): x, y = dataset[i] print('索引: ', i, '原始_x: ', x, '原始_y: ', y) x_, y_ = new_dataset[i] print('索引: ', i, '转换后的 x_:', x_, '转换后的 y_:', y_) |
总结
在本教程中,您学习了如何在 PyTorch 中创建自定义数据集和转换。特别是,您学到了:
- 如何创建简单的数据集类并对其应用转换。
- 如何构建可调用转换并将其应用于数据集对象。
- 如何对数据集对象组合各种转换。
暂无评论。