在 PyTorch 中使用 LSTM 生成文本

循环神经网络可用于时间序列预测。其中,创建了一个回归神经网络。它也可以用作生成模型,通常是一个分类神经网络模型。生成模型旨在从数据中学习特定模式,以便在给出一些提示时,它可以创建一个与所学模式风格相同的完整输出。

在这篇文章中,您将学习如何使用 PyTorch 中的 LSTM 循环神经网络构建一个用于文本的生成模型。完成这篇文章后,您将了解:

  • 在哪里下载免费的文本语料库,用于训练文本生成模型。
  • 如何将文本序列问题构建成循环神经网络生成模型。
  • 如何开发一个 LSTM 模型,为给定问题生成合理的文本序列。

通过我的《用PyTorch进行深度学习》一书来启动你的项目。它提供了包含可用代码自学教程


让我们开始吧。

在 PyTorch 中使用 LSTM 生成文本
图片来源:Egor Lyfar。保留部分权利。

概述

这篇文章分为六个部分;它们是:

  • 什么是生成模型
  • 获取文本数据
  • 用于预测下一个字符的小型 LSTM 网络
  • 使用 LSTM 模型生成文本
  • 使用更大的 LSTM 网络
  • 使用 GPU 加快训练

什么是生成模型

生成模型实际上只是另一种能够创造新事物的机器学习模型。生成对抗网络(GAN)自成一类。使用注意力机制的 Transformer 模型也被发现对生成文本段落很有用。

它只是一种机器学习模型,因为该模型已经用现有数据进行训练,从而从中学习到了一些东西。根据训练方式的不同,它们的工作方式可能大相径庭。在这篇文章中,将创建一个基于字符的生成模型。这意味着训练一个模型,它以一串字符(字母和标点符号)作为输入,并以紧随其后的下一个字符作为目标。只要它能根据前面的字符预测下一个字符,您就可以循环运行该模型来生成一段长文本。

这个模型可能是最简单的。然而,人类语言是复杂的。您不应该期望它能产生非常高质量的输出。即便如此,您也需要大量数据并长时间训练模型才能看到合理的结果。

想开始使用PyTorch进行深度学习吗?

立即参加我的免费电子邮件速成课程(附示例代码)。

点击注册,同时获得该课程的免费PDF电子书版本。

获取文本数据

获取高质量数据对于成功的生成模型至关重要。幸运的是,许多经典文本不再受版权保护。这意味着您可以免费下载这些书籍的所有文本,并将其用于实验,例如创建生成模型。也许获取不受版权保护的免费书籍的最佳地点是古腾堡计划。

在这篇文章中,您将使用一本童年最喜欢的书作为数据集,即刘易斯·卡罗尔的《爱丽丝梦游仙境》。

您的模型将学习字符之间的依赖关系以及序列中字符的条件概率,以便您可以反过来生成全新的、原创的字符序列。这篇文章非常有趣,建议您用古腾堡计划中的其他书籍重复这些实验。这些实验不仅限于文本;您还可以尝试其他 ASCII 数据,例如计算机源代码、LaTeX、HTML 或 Markdown 等标记文档,等等。

您可以免费下载这本书的完整 ASCII 格式(纯文本 UTF-8)文本,并将其放入您的工作目录,文件名为 wonderland.txt。现在,您需要准备数据集以供建模。古腾堡计划为每本书添加了标准的页眉和页脚,这些不属于原始文本。在文本编辑器中打开文件并删除页眉和页脚。页眉很明显,以以下文本结尾:

页脚是文本行之后的所有文本,该文本行写着:

您应该得到一个大约有 3,400 行文本的文件。

预测下一个字符的小型 LSTM 网络

首先,在构建模型之前,您需要对数据进行一些预处理。神经网络模型只能处理数字,不能处理文本。因此,您需要将字符转换为数字。为了简化问题,您还希望将所有大写字母转换为小写。

在下面,您将打开文本文件,将所有字母转换为小写,并创建一个 Python 字典 char_to_int,将字符映射到不同的整数。例如,书中独特的排序小写字符列表如下:

由于此问题是基于字符的,“词汇表”是文本中使用的所有不同字符。

这应该会打印:

您可以看到这本书的字符数略低于 150,000,转换为小写后,词汇表中只有 50 个不同的字符供网络学习——这比字母表中的 26 个多得多。

接下来,您需要将文本分成输入和目标。这里使用 100 个字符的窗口。也就是说,以字符 1 到 100 作为输入,您的模型将预测字符 101。如果使用 5 个字符的窗口,单词“chapter”将变成两个数据样本:

在像这样长的文本中,可以创建无数个窗口,从而生成大量样本的数据集。

运行上述代码,您可以看到总共创建了 144,474 个样本。每个样本现在都是整数形式,使用 char_to_int 映射进行转换。然而,PyTorch 模型更喜欢以浮点张量的形式查看数据。因此,您应该将它们转换为 PyTorch 张量。模型中将使用 LSTM 层,因此输入张量应为 (样本, 时间步长, 特征) 维度。为了帮助训练,将输入归一化到 0 到 1 也是一个好主意。因此,您有以下代码:

您现在可以定义您的 LSTM 模型。在这里,您定义了一个具有 256 个隐藏单元的单个隐藏 LSTM 层。输入是单个特征(即,一个字符对应一个整数)。在 LSTM 层之后添加了一个概率为 0.2 的 dropout 层。LSTM 层的输出是一个元组,其中第一个元素是每个时间步的 LSTM 单元的隐藏状态。它是 LSTM 单元接受每个时间步输入时隐藏状态演变的历史记录。据推测,最后一个隐藏状态包含的信息最多,因此只有最后一个隐藏状态被传递到输出层。输出层是一个全连接层,用于为 50 个词汇生成 logits。logits 可以使用 softmax 函数转换为类似概率的预测。

这是一个用于 50 个类别的单字符分类模型。因此,应使用交叉熵损失。它使用 Adam 优化器进行优化。训练循环如下。为简单起见,没有创建测试集,但在每个 epoch 结束时会再次使用训练集评估模型,以跟踪进度。

这个程序可能会运行很长时间,尤其是在 CPU 上!为了保存工作成果,最好将找到的最佳模型保存起来以备将来重用。

运行上述代码可能会产生以下结果:

交叉熵几乎总是在每个 epoch 中递减。这意味着模型可能尚未完全收敛,您可以对其进行更多 epoch 的训练。训练循环完成后,您应该会创建文件 single-char.pth,其中包含找到的最佳模型权重以及该模型使用的字符到整数映射。

为完整起见,下面将以上所有内容整合到一个脚本中:

使用 LSTM 模型生成文本

假设模型训练良好,使用训练好的 LSTM 网络生成文本相对简单。首先,您需要重新创建网络并从保存的检查点加载训练好的模型权重。然后,您需要为模型创建一些起始提示。提示可以是模型能够理解的任何内容。它是一个种子序列,将提供给模型以获取一个生成的字符。然后,将生成的字符添加到该序列的末尾,并截掉第一个字符以保持一致的长度。这个过程将重复,直到您希望预测新字符(例如,一个长度为 1,000 个字符的序列)。您可以选择一个随机输入模式作为您的种子序列,然后在生成字符时打印它们。

生成提示的一种简单方法是从原始数据集中随机选择一个样本,例如,使用上一节中获得的 raw_text,可以创建如下提示:

但是您应该注意,您需要对其进行转换,因为此提示是一个字符串,而模型期望的是一个整数向量。

整个代码仅如下所示:

运行此示例首先输出使用的提示,然后是生成的每个字符。例如,下面是此文本生成器一次运行的结果。提示是:

生成的文本是:

让我们注意一下生成文本的一些观察结果。

  • 它可以发出换行符。原始文本将行宽限制为 80 个字符,生成模型试图复制此模式。
  • 字符被分成类似单词的组,有些组是实际的英文单词(例如“the”、“said”和“rabbit”),但许多不是(例如“thite”、“soteet”和“tha”)。
  • 序列中的一些单词有意义(例如“i don’t know the”),但许多没有意义(例如“he were thing”)。

这本书的这种基于字符的模型能产生这样的输出,令人印象深刻。它让您感受到了 LSTM 网络的学习能力。然而,结果并不完美。在下一节中,您将通过开发一个更大的 LSTM 网络来提高结果的质量。

使用更大的 LSTM 网络

回想一下,LSTM 是一种循环神经网络。它以序列作为输入,在序列的每一步中,输入与其内部状态混合以产生输出。因此,LSTM 的输出也是一个序列。在上述情况下,LSTM 层的最后一个时间步的输出被用于神经网络的进一步处理,而早期时间步的输出则被丢弃。然而,情况并非总是如此。您可以将一个 LSTM 层的序列输出作为另一个 LSTM 层的输入。这样,您就在构建一个更大的网络。

与卷积神经网络类似,堆叠 LSTM 网络旨在让较早的 LSTM 层学习低级特征,而较晚的 LSTM 层学习高级特征。它可能并非总是有效,但您可以尝试一下,看看模型是否能产生更好的结果。

在 PyTorch 中,构建堆叠 LSTM 层很容易。让我们将上述模型修改为以下内容:

唯一的改变是 nn.LSTM() 的参数:您将 num_layers 设置为 2 而不是 1 以添加另一个 LSTM 层。但在两个 LSTM 层之间,您还通过参数 dropout=0.2 添加了一个 dropout 层。用此模型替换之前的模型是您需要进行的所有更改。重新运行训练,您应该会看到以下结果:

您应该会看到这里的交叉熵低于上一节中的交叉熵。这意味着这个模型表现更好。事实上,有了这个模型,您可以看到生成的文本看起来更合理:

不仅单词拼写正确,文本也更像英语。由于在训练模型时交叉熵损失仍在减少,您可以假设模型尚未收敛。如果您增加训练周期,可以期望模型会更好。

为完整起见,下面是使用此新模型的完整代码,包括训练和文本生成。

使用 GPU 加快训练

运行这篇文章中的程序可能会非常慢。即使您有 GPU,您也不会立即看到改进。这是因为 PyTorch 的设计,它可能不会自动使用您的 GPU。但是,如果您有支持 CUDA 的 GPU,通过将繁重的计算从 CPU 转移走,您可以大大提高性能。

PyTorch 模型是一个张量计算程序。张量可以存储在 GPU 或 CPU 中。只要所有操作符都在同一个设备上,就可以执行操作。在这个特定的示例中,模型权重(即 LSTM 层和全连接层的权重)可以移动到 GPU。通过这样做,输入也应该在执行前移动到 GPU,除非您将其移回,否则输出也将存储在 GPU 中。

在 PyTorch 中,您可以使用以下函数检查您是否拥有支持 CUDA 的 GPU:

它返回一个布尔值,指示您是否可以使用 GPU,这反过来取决于您的硬件型号、您的操作系统是否安装了适当的库以及您的 PyTorch 是否使用相应的 GPU 支持进行编译。如果一切正常,您可以创建一个设备并将您的模型分配给它:

如果您的模型正在 CUDA 设备上运行,但您的输入张量不在,您会看到 PyTorch 对此进行抱怨并无法继续。要将您的张量移动到 CUDA 设备,您应该像下面这样运行:

其中 .to(device) 部分将发挥作用。但请记住,上面生成的 y_pred 也将在 CUDA 设备上。因此,当您运行损失函数时,您应该做同样的事情。修改上面的程序使其能够在 GPU 上运行将变成以下内容:

与上一节的代码相比,您应该会发现它们基本相同。除了通过以下行检测到 CUDA 设备:

这将是您的 GPU,如果未找到 CUDA 设备,则会回退到 CPU。之后,在几个关键位置添加了 .to(device),以将计算转移到 GPU。

进一步阅读

这种字符文本模型是使用循环神经网络生成文本的流行方式。如果您有兴趣深入了解,下面提供了一些关于该主题的更多资源和教程。

文章

论文

  • Ilya Sutskever、James Martens 和 Geoffrey Hinton。“使用循环神经网络生成文本”。载于:第 28 届国际机器学习会议论文集。美国华盛顿州贝尔维尤,2011 年。

API

总结

在这篇文章中,您了解了如何开发用于 PyTorch 文本生成的 LSTM 循环神经网络。完成这篇文章后,您将了解:

  • 如何免费查找经典书籍文本作为机器学习模型的数据集
  • 如何训练用于文本序列的 LSTM 网络
  • 如何使用 LSTM 网络生成文本序列以及如何使用 CUDA 设备优化 PyTorch 中的深度学习训练

开始使用PyTorch进行深度学习!

Deep Learning with PyTorch

学习如何构建深度学习模型

...使用新发布的PyTorch 2.0库

在我的新电子书中探索如何实现
使用 PyTorch进行深度学习

它提供了包含数百个可用代码自学教程,让你从新手变成专家。它将使你掌握:
张量操作训练评估超参数优化等等...

通过动手练习开启你的深度学习之旅


查看内容

《使用 PyTorch 中的 LSTM 进行文本生成》的 5 条回复

  1. Zineb 2023 年 4 月 4 日 晚上 9:27 #

    非常感谢这篇精彩的文章。

    torch.save([best_model, char_to_dict], “single-char.pth”) 应该是 torch.save([best_model, char_to_int], “single-char.pth”)

  2. Emmett Polhemus 2024 年 3 月 20 日 凌晨 5:05 #

    我遇到了算法不打印最终输出而是停止的问题。

    • James Carmichael 2024 年 3 月 20 日 早上 8:49 #

      Emmett 你好……请提供你可能遇到的任何具体错误信息。这将更好地帮助我们指导你。

      • Emmett polhemus 2024 年 3 月 21 日 晚上 10:10 #

        我已经让它工作了,错误与我使用的 single-char.path 文件有关,我用的是 .txt 文件。另外,每次想要提示时都需要训练它吗?

  3. shadow 2024 年 4 月 8 日 早上 10:41 #

    非常好

发表评论

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。