文本生成是深度学习最引人入胜的应用之一。随着 GPT-2 等大型语言模型的出现,我们现在可以生成连贯、上下文相关且极具创造性的人类般文本。在本教程中,您将学习如何使用 GPT-2 实现文本生成。您将通过可以立即运行的实际示例进行学习,通过本指南,您将理解其理论和实际实现细节。
完成本教程后,您将了解:
- GPT-2 的 Transformer 架构如何实现复杂的文本生成
- 如何使用不同的采样策略实现文本生成
- 如何针对不同的用例优化生成参数
通过我的书籍《Hugging Face Transformers中的NLP》,快速启动您的项目。它提供了带有工作代码的自学教程。
让我们开始吧。

使用 GPT-2 模型进行文本生成
图片由 Peter Herrmann 提供。保留部分权利。
概述
本教程分为四个部分;它们是:
- 核心文本生成实现
- 文本生成中的参数有哪些?
- 批处理和填充
- 优化生成结果的技巧
核心文本生成实现
让我们从一个演示基本概念的基本实现开始。下面,您将创建一个类,使用预训练的 GPT-2 模型,根据给定的提示生成文本。您将在本教程的后续部分扩展此类别。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import torch from transformers import GPT2LMHeadModel, GPT2Tokenizer class TextGenerator: def __init__(self, model_name='gpt2'): """使用预训练模型初始化文本生成器。 参数 model_name (str): 要使用的预训练模型的名称。 可以是 'gpt2'、'gpt2-medium'、'gpt2-large' 中的任意一个 """ self.tokenizer = GPT2Tokenizer.from_pretrained(model_name) self.model = GPT2LMHeadModel.from_pretrained(model_name) self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.model.to(self.device) def generate_text(self, prompt, max_length=100, temperature=0.7, top_k=50, top_p=0.95): """根据输入提示生成文本。 参数 prompt (str): 用于继续生成的输入文本 max_length (int): 生成文本的最大长度 temperature (float): 控制生成中的随机性 top_k (int): 要考虑的最高概率令牌数量 top_p (float): 用于令牌过滤的累积概率阈值 Returns str: 包含提示的生成文本 """ try: # 编码输入提示 inputs = self.tokenizer(prompt, return_tensors="pt") input_ids = inputs["input_ids"].to(self.device) attention_mask = inputs["attention_mask"].to(self.device) # 配置生成参数 gen_kwargs = { "max_length": max_length, "temperature": temperature, "top_k": top_k, "top_p": top_p, "pad_token_id": self.tokenizer.eos_token_id, "no_repeat_ngram_size": 2, "do_sample": True, } # 生成文本 with torch.no_grad(): output_sequences = self.model.generate( input_ids, attention_mask=attention_mask, **gen_kwargs ) # 解码并返回生成的文本 generated_text = self.tokenizer.decode( output_sequences[0], skip_special_tokens=True ) return generated_text except Exception as e: print(f"Error during text generation: {str(e)}") return prompt |
让我们分解一下这个实现。
在此代码中,您使用 `transformers` 库中的 `GPT2LMHeadModel` 和 `GPT2Tokenizer` 类来加载预训练的 GPT-2 模型和分词器。作为用户,您甚至不需要了解 GPT-2 的工作原理。`TextGenerator` 类托管它们,如果您有 GPU,则在 GPU 上使用它们。如果您尚未安装该库,可以使用 `pip` 命令进行安装。
1 |
pip install transformers torch |
在 `generate_text` 方法中,您处理核心生成过程,其中包含几个重要参数:
- `max_length`:控制生成文本的最大长度
- `temperature`:调整随机性(值越高 = 越有创意)
- `top_k`:将词汇限制为 $k$ 个最高概率的令牌
- `top_p`:使用核采样动态限制令牌
以下是使用此实现生成文本的方法:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
... # 创建一个文本生成器实例 generator = TextGenerator() # 示例 1:基本文本生成 prompt = "人工智能的未来将" generated_text = generator.generate_text(prompt) print(f"生成的文本:\n{generated_text}\n") # 示例 2:使用更高温度进行更具创造性的生成 creative_text = generator.generate_text( prompt="曾几何时", temperature=0.9, max_length=200 ) print(f"创造性生成:\n{creative_text}\n") # 示例 3:使用更低温度进行更集中的生成 focused_text = generator.generate_text( prompt="机器学习的好处包括", temperature=0.5, max_length=150 ) print(f"集中生成:\n{focused_text}\n") |
输出可能为
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
生成的文本 人工智能的未来将取决于它学习的程度以及如何适应新情况。 未来也可能不像我们想象的那么好。就目前而言,我们正在处理比人脑更复杂的人工智能。它会做我们无法控制的事情,例如玩电脑来寻找线索,让您能够阻止汽车移动。但如果我们能弄清楚如何使用 创造性生成 曾几何时就是这样。当我拍这张照片时,我也有类似的情况。其他情况下也发生过这种情况。 以及给经历过此问题的读者的提示 '我想你可能也遇到过同样的问题,'我不知道你还会继续这样做多久。尽量让它尽可能多地通过你。周围有大量的负能量,你可以和你的朋友、家人和同事一起尝试。尽你所能去理解它。你不需要在这方面做得非常好。' -约翰·L·戈塞特,前中情局官员 . 集中生成 机器学习的好处包括 提高了预测的准确性。 提高了预测未来的准确性。增加了对自然世界的理解。更准确的预测和更好的未来事件预测。预测未来结果的概率更高。预测误差风险低。误差风险更低。基于数据优化预测。从前几年推断数据。基于过去经验预测过去几年。更好的预测准确性。使用更强大的机器可以做出更准确的预测。好处包括:- - 改进了对环境未来变化的预测。这可以降低未来行动出错的风险。- 提高了对未来趋势预测的可预测性。如果您无法预测未来的发展,您 |
您在这里使用了三个不同的提示,并生成了三段文本。该模型使用起来非常简单。您只需将分词器编码的提示以及注意力掩码传递给 `generate_text` 方法。注意力掩码由分词器提供,但本质上只是与输入形状相同的全一张量。
文本生成中的参数有哪些?
如果您查看 `generate_text` 方法,您会看到通过 `gen_kwargs` 传递了几个参数。其中一些最重要的参数是 `top_k`、`top_p` 和 `temperature`。您可以通过尝试不同的值来查看 `top_k` 和 `top_p` 的效果。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
... generator = TextGenerator() # 采样效果示例 prompt = "科学家发现" # 使用 top-k 采样 top_k_text = generator.generate_text( prompt, top_k=10, top_p=1.0, max_length=50 ) print(f"Top-k 采样 (k=10):\n{top_k_text}\n") # 使用核 (top-p) 采样 nucleus_text = generator.generate_text( prompt, top_k=0, top_p=0.9, max_length=50 ) print(f"核采样 (p=0.9):\n{nucleus_text}\n") # 结合两者 combined_text = generator.generate_text( prompt, top_k=50, top_p=0.95, max_length=50 ) print(f"组合采样:\n{combined_text}\n") |
示例输出可能为
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
Top-k 采样 (k=10) 科学家发现,只要分子不相互接触,该蛋白质就能够与受体结合。然后,科学家们利用这一点来研究蛋白质合成对身体自然免疫系统的影响。 The 核采样 (p=0.9) 科学家发现,空气中的氮、碳和氧都是碳原子。 “我们知道氮和碳在大气中非常微小,含量很少。但我们不知道这对整个地球意味着什么,”他说。 组合采样 科学家发现,阻止病毒传播的唯一方法是向体内引入少量细菌。 “我们想开发一种疫苗来阻止病毒进入我们的血液,”他说。 |
`top_k` 和 `top_p` 参数用于微调采样策略。要理解它,请记住模型为每个标记输出一个词汇表上的概率分布。有大量的标记。当然,您总是可以选择概率最高的标记,但您也可以选择一个随机标记,以便从相同的提示中生成不同的输出。这就是 GPT-2 使用的**多项式采样**算法。
`top_k` 参数将选择限制为 $k>0$ 个最有可能的标记。与其考虑词汇表中的数千个标记,不如将 `top_k` 设置为更易于处理的子集。
`top_p` 参数进一步缩短了选择范围。它只考虑那些累积概率达到 `top_p` 参数 $P$ 的标记。然后根据概率对生成的标记进行采样。
上面的代码演示了三种不同的采样方法。
- 第一个示例将 `top_k` 设置为一个小值,限制了选择。输出是集中的,但可能重复。
- 第二个示例通过将其设置为 0 来关闭 `top_k`。它设置 `top_p` 以使用核采样。采样池将移除低概率标记,提供更自然的变体。
- 第三个示例是一种组合方法,利用这两种策略以获得最佳结果。设置更大的 `top_k` 以允许更好的多样性,因此随后更大的 `top_p` 仍然可以提供高质量、自然的生成。
然而,一个令牌的概率是多少呢?这就是温度参数。我们来看另一个例子。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
... generator = TextGenerator() # 温度效应示例 prompt = "机器人小心翼翼地" # 低温度(更集中) focused = generator.generate_text( prompt, temperature=0.3, max_length=50 ) print(f"低温度 (0.3):\n{focused}\n") # 中等温度(平衡) balanced = generator.generate_text( prompt, temperature=0.7, max_length=50 ) print(f"中等温度 (0.7):\n{balanced}\n") # 高温度(更具创造性) creative = generator.generate_text( prompt, temperature=1.0, max_length=50 ) print(f"高温度 (1.0):\n{creative}\n") |
请注意,所有三个示例都使用相同的提示。输出可能为:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
低温度 (0.3) 机器人小心翼翼地将头部向左移动,然后机器人的头部向右移动。机器人随后回到正常位置。 下次你看到机器人时,你会看到它们向不同的方向移动。它们 中等温度 (0.7) 机器人小心翼翼地移动着握着物体的人的手臂和腿。然而,机器人仍然一动不动, 机器人无法尝试移动手臂或腿。 那个人的身体 高温度 (1.0) 机器人小心翼翼地穿过机器人,下一刻,它又回到了控制室。他起身从 地板上走过,一秒钟后,他被撞伤了。然后我们看到同一个机器人的第三部分 |
那么温度有什么影响呢?您可以看到
- 0.3 的低温度会产生更集中和确定性的输出。输出很无聊。使其适用于需要准确性的任务。
- 0.7 的中等温度在创造性和连贯性之间取得了平衡。
- 1.0 的高温度生成更多样化和创造性的文本。每个示例都使用相同的 max_length 进行公平比较。
在幕后,温度是 softmax 函数中的一个参数,该函数应用于模型的输出以确定输出标记。softmax 函数是:
$$
s(x_j) = \frac{e^{x_j/T}}{\sum_{i=1}^{V} e^{x_i/T}}
$$
其中 $T$ 是温度参数,$V$ 是词汇量大小。用 $T$ 缩放模型输出 $x_1,\dots,x_V$ 会改变标记的相对概率。高温度使概率更均匀,让不太可能的标记更有可能被选择。低温度使概率更集中在最高概率标记上,因此输出更具确定性。
批处理和填充
上面的代码适用于单个提示。但是,在实践中,您可能需要为多个提示生成文本。以下代码显示了如何有效地处理多个提示的生成。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import torch from transformers import GPT2LMHeadModel, GPT2Tokenizer class BatchGenerator: def __init__(self, model_name="gpt2"): """使用预训练模型初始化文本生成器。 参数 model_name (str): 要使用的预训练模型的名称。 可以是:"gpt2"、"gpt2-medium"、"gpt2-large" """ self.tokenizer = GPT2Tokenizer.from_pretrained(model_name) self.tokenizer.add_special_tokens({'pad_token': self.tokenizer.eos_token}) self.model = GPT2LMHeadModel.from_pretrained(model_name) self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model.to(self.device) def generate_batch(self, prompts, **kwargs): """高效地为多个提示生成文本。 参数 prompts (list): 输入提示列表 batch_size (int): 每次处理的提示数量 **kwargs: 其他生成参数 Returns list: 每个提示的生成文本 """ inputs = self.tokenizer(prompts, padding=True, padding_side="left", return_tensors="pt") outputs = self.model.generate( inputs["input_ids"].to(self.device), attention_mask=inputs["attention_mask"].to(self.device), **kwargs ) results = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) return results # 批量生成示例用法 batch_generator = BatchGenerator() prompts = [ "人工智能的未来", "太空探索将", "在未来十年", "气候变化已" ] generated_texts = batch_generator.generate_batch( prompts, max_length=100, temperature=0.7, do_sample=True, ) for prompt, text in zip(prompts, generated_texts): print(f"\n提示: {prompt}") print(f"生成文本: {text}") |
输出可能为
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
提示: 人工智能的未来 生成文本:人工智能的未来充满不确定性,很难预测它将如何发展,”东京东北大学人工智能与机器学习中心主任松尾由纪教授说。 “但即使人工智能不是人类安全的唯一威胁,它也将是最重要的威胁之一,这将改变我们对机器人未来发展的看法。” 本文经授权转载,首次发表于5月 提示词:太空探索将 生成文本:太空探索也将是一个挑战,航天局的航天飞机编队正在接近其到2030年实现15亿人容量的最终目标。 虽然航天飞机能够将宇航员运送到国际空间站并从那里返回,但美国宇航局的新型航天飞机(首次载人登月任务)目前已签订为期六年的合同。该机构还在开发一个价值100亿美元的卫星轨道推进系统,该系统将使载人飞船能够 提示词:在未来十年 生成文本:根据美国经济顾问委员会的数据,在未来十年,美国收入最高的10%人群的平均工资从每小时12.50美元上涨到16.50美元。与此同时,美国收入最高的20%人群的年收入接近169亿美元。 收入最高的1%是大多数美国人的主要收入来源,中高收入群体的收入几乎是收入最低的40%人群的两倍。 The 提示词:气候变化已 生成文本:气候变化降低了发生自然气候变化的可能性。 事实上,气候变化变得更加频繁和严重的可能性极高。因此,任何旨在促进或减少极端天气事件发生的政策,都极有可能导致包括美国在内的极端天气事件。 飓风、洪水和降雪等极端天气事件的风险是发生发展风险的两倍多。 |
BatchGenerator
的实现做了一些细微的改动。 generate_batch
方法接受一个提示列表,并将其他参数传递给模型的 generate
方法。最重要的是,它将提示填充到相同的长度,然后为批次中的每个提示生成文本。结果以与提示相同的顺序返回。
GPT-2 模型经过训练可以处理批量输入。但要将输入以张量形式呈现,所有提示都需要填充到相同的长度。分词器可以轻松处理批量输入。但 GPT-2 模型没有指定填充标记应该是什么。因此,您需要使用 add_special_tokens()
函数来指定它。上面的代码使用 EOS 标记。但实际上,您可以使用任何标记,因为注意力掩码将强制模型忽略它。
优化生成结果的技巧
您知道如何使用 GPT-2 模型生成文本。但您应该对输出有什么期望呢?这确实是一个取决于任务的问题。但这里有一些技巧可以帮助您获得更好的结果。
首先是提示工程。您需要在提示中具体明确,才能获得高质量的输出。模糊的词语或短语可能导致模糊的输出,因此您应该具体、简洁、精确。您还可以包含相关上下文,以帮助模型理解任务。
此外,您还可以调整参数以获得更好的结果。根据任务,您可能希望输出更专注或更具创造性。您可以调整温度参数来控制输出的随机性。您还可以调整 temperature
、top_k
和 top_p
参数来控制输出的多样性。输出生成是自回归的。您可以设置 max_length
参数来通过权衡速度来控制输出的长度。
最后,上面的代码不是容错的。在生产环境中,您需要实现适当的错误处理、设置合理的超时、监控内存使用情况以及实现速率限制。
进一步阅读
下面是一些进一步的阅读材料,可以帮助您更好地理解使用 GPT-2 模型进行文本生成。
- 语言模型是无监督多任务学习者,作者:Alec Radford,Jeffrey Wu 等人 (2019)
- 对比搜索是神经文本生成所需的,作者:Yixuan Su 和 Nigel Collier (2022)
- 使用 Transformer 中的对比搜索生成人类水平的文本
总结
在本教程中,您学习了如何使用 GPT-2 生成文本,并使用 Transformer 库通过几行代码构建实际应用程序。特别是,您学习了
- 如何使用 GPT-2 实现文本生成
- 如何针对不同用例控制生成参数
- 如何实现批处理以提高效率
- 最佳实践和要避免的常见陷阱
暂无评论。