在 PyTorch 中使用单层神经网络构建图像分类器

单层神经网络,也称为单层感知器,是最简单的神经网络类型。它仅包含一层神经元,这些神经元连接到输入层和输出层。对于图像分类器,输入层是图像,输出层是类别标签。

要使用 PyTorch 中的单层神经网络构建图像分类器,您首先需要准备数据。这通常涉及将图像和标签加载到 PyTorch 数据加载器中,然后将数据拆分为训练集和验证集。数据准备好后,您就可以定义神经网络了。

接下来,您可以使用 PyTorch 的内置函数在训练数据上训练网络,并在验证数据上评估其性能。您还需要选择一个优化器,例如随机梯度下降 (SGD),以及一个损失函数,例如交叉熵损失。

请注意,单层神经网络可能并非适用于所有任务,但它可以作为一个简单的分类器,也有助于您理解神经网络的内部工作原理并进行调试。

因此,让我们来构建我们的图像分类器。在这个过程中,您将学习

  • 如何在 PyTorch 中使用和预处理内置数据集。
  • 如何在 PyTorch 中构建和训练自定义神经网络。
  • 如何在 PyTorch 中构建分步图像分类器。
  • 如何在 PyTorch 中使用训练好的模型进行预测。

让我们开始吧。

使用 PyTorch 单层神经网络构建图像分类器。
图片由 Alex Fung 拍摄。部分权利保留。

概述

本教程分为三个部分;它们是

  • 准备数据集
  • 构建模型架构
  • 训练模型

准备数据集

在本教程中,您将使用 CIFAR-10 数据集。这是一个图像分类数据集,包含 10 个类别中 60,000 张 32×32 像素的彩色图像,每个类别有 6,000 张图像。其中有 50,000 张训练图像和 10,000 张测试图像。类别包括飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。CIFAR-10 是机器学习和计算机视觉研究中一个流行的数据集,因为它相对较小且简单,但又足够具有挑战性,需要使用深度学习方法。此数据集可以轻松导入到 PyTorch 库中。

以下是具体操作方法。

如果您以前从未下载过数据集,您可能会看到此代码显示图像是从何处下载的

您指定了数据集应下载到的 root 目录,通过设置 train=True 来导入训练集,设置 train=False 来导入测试集。download=True 参数将在指定 root 目录中不存在数据集时下载数据集。

构建神经网络模型

让我们定义一个继承自 torch.nn.Module 的简单神经网络 SimpleNet。网络在 __init__ 方法中有两个全连接 (fc) 层,fc1fc2。第一个全连接层 fc1 以图像作为输入,并具有 100 个隐藏神经元。同样,第二个全连接层 fc2 具有 100 个输入神经元和 num_classes 个输出神经元。由于有 10 个类别,num_classes 参数默认为 10。

此外,forward 方法定义了网络的前向传播,其中输入 x 通过 __init__ 方法中定义的层。该方法首先使用 view 方法重塑输入张量 x 以获得所需的形状。然后,输入将通过全连接层及其激活函数,最后返回输出张量。

通过我的《用PyTorch进行深度学习》一书来启动你的项目。它提供了包含可用代码自学教程


以下是以上所有内容的实现代码。

并且,编写一个函数来可视化这些数据,这在您稍后训练模型时也会很有用。

现在,让我们实例化模型对象。

想开始使用PyTorch进行深度学习吗?

立即参加我的免费电子邮件速成课程(附示例代码)。

点击注册,同时获得该课程的免费PDF电子书版本。

训练模型

您将创建 PyTorch 的 DataLoader 类的两个实例,分别用于训练和测试。在 train_loader 中,您将批次大小设置为 64,并通过设置 shuffle=True 来随机打乱训练数据。

然后,您将定义用于训练模型的交叉熵损失函数和 Adam 优化器。您为优化器设置的学习率为 0.001。

test_loader 也是类似的,只是我们不需要打乱数据。

最后,让我们设置一个训练循环来训练我们的模型几个 epoch。您将定义一些空列表来存储损失和准确率指标的值。

运行此循环将向您显示以下内容

正如您所见,单层分类器仅训练了 20 个 epoch,其验证准确率约为 47%。如果训练更多 epoch,您可能会获得不错的准确率。同样,我们的模型只有一个包含 100 个隐藏神经元的层。如果添加更多层,准确率可能会显著提高。

现在,让我们绘制损失和准确率图,看看它们的样子。

损失图是这样的:准确率图如下:

这是您可以看到模型如何根据真实标签进行预测的方法。

打印的标签如下:

这些标签对应于以下图像

总结

在本教程中,您学习了如何仅使用单层神经网络构建图像分类器。特别是,您学习了

  • 如何在 PyTorch 中使用和预处理内置数据集。
  • 如何在 PyTorch 中构建和训练自定义神经网络。
  • 如何在 PyTorch 中构建分步图像分类器。
  • 如何在 PyTorch 中使用训练好的模型进行预测。

开始使用PyTorch进行深度学习!

Deep Learning with PyTorch

学习如何构建深度学习模型

...使用新发布的PyTorch 2.0库

在我的新电子书中探索如何实现
使用 PyTorch进行深度学习

它提供了包含数百个可用代码自学教程,让你从新手变成专家。它将使你掌握:
张量操作训练评估超参数优化等等...

通过动手练习开启你的深度学习之旅


查看内容

12 条对 使用 PyTorch 单层神经网络构建图像分类器 的回复

  1. Tony the Riger 2023年1月20日 上午8:35 #

    杰森在哪里?

  2. Chuck 2023年2月2日 上午8:08 #

    模型工作正常,但之后的代码,即:
    准确率/损失图和预测图像,
    内核因此消息而崩溃:
    ‘内核似乎已崩溃。它将自动重启’。我试了好几次,结果都一样。有什么建议解决这个问题吗?谢谢。

  3. Leo 2023年2月24日 上午8:23 #

    # 创建 Data 对象
    “dataset = Data()”

    Pytorch 报告了后续错误。数据类未创建?

    • James Carmichael 2023 年 2 月 24 日上午 11:03 #

      你好 Leo… 请详细说明您的问题,以便我们能更好地帮助您。也就是说…您收到的是什么错误?

      • Fabio 2023 年 4 月 20 日下午 8:16 #

        在本教程中,它可能被遗漏了,但在其他教程中存在。

        在 dataset = Data() 之前使用此代码。

        # 创建数据集类
        class Data(Dataset)
        def __init__(self)
        self.x = torch.arange(-2, 2, 0.1).view(-1, 1)
        self.y = torch.zeros(self.x.shape[0], 1)
        self.y[self.x[:, 0] > 0.2] = 1
        self.len = self.x.shape[0]

        def __getitem__(self, idx)
        return self.x[idx], self.y[idx]

        def __len__(self)
        return self.len

        • James Carmichael 2023 年 4 月 21 日上午 9:29 #

          感谢 Fabio 的反馈和建议!

          • Narae 2023 年 7 月 31 日下午 5:22 #

            你好,

            我认为

            dataset = Data()

            部分是错误地留在代码中的?它在后面的代码中似乎没有被使用,而且 Data() 方法在使用前从未定义过。

          • James Carmichael 2023 年 8 月 1 日上午 9:15 #

            你好 Narae… 感谢您的反馈!

  4. Leo 2023 年 2 月 24 日上午 11:37 #

    你好 Kames,
    谢谢回复。

    代码和错误在这里:

    ————————————————-

    import torch
    import torchvision
    import torchvision.transforms as transforms

    # 导入 CIFAR-10 数据集
    train_set = torchvision.datasets.CIFAR10(root=’./data’, train=True, download=True, transform=transforms.ToTensor())
    test_set = torchvision.datasets.CIFAR10(root=’./data’, train=False, download=True, transform=transforms.ToTensor())

    # 创建 Data 对象
    dataset = Data()

    ————————————————————–

    NameError Traceback (最近一次调用)
    Input In [1], in ()
    7 test_set = torchvision.datasets.CIFAR10(root=’./data’, train=False, download=True, transform=transforms.ToTensor())
    9 # Create the Data object
    —> 10 dataset = Data()

    NameError: name ‘Data’ is not defined

  5. Sean O'Connor 2023 年 4 月 1 日下午 3:46 #

    即使是单层神经网络,其宽度扩展性也不好。
    如果宽度是 n,则所需的乘加运算次数是 n 的平方。
    从 n=8 的合理值(给出 64 次乘加运算)开始,到 n=256(给出 65536 次乘加运算)时变得不合理。
    但是,通过使用组合算法,您可以控制成本。
    https://ai462qqq.blogspot.com/2023/03/switch-net-4-reducing-cost-of-neural.html

    • James Carmichael 2023 年 4 月 2 日上午 6:21 #

      感谢 Sean 的反馈和贡献!

留下回复

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。