Transformer models are the standard models for NLP tasks today. Almost all NLP tasks involve generating text, but text is not the direct output of the model. You may expect the model to help you produce text that is coherent and relevant to the context. While this depends partly on the quality of the model, the generation parameters also play a crucial role in the quality of the generated text.
In this post, you will explore the key parameters that control text generation in transformer models. You will see how these parameters affect the quality of the generated text and how to tune them for different applications. In particular, you will learn:
- The key parameters that control text generation in transformer models
- Different decoding strategies
- How to control the creativity and coherence of the generated text
- How to tune generation parameters for specific applications
Kick-start your project with my book NLP with Hugging Face Transformers. It provides self-study tutorials with working code.
Let's get started!

Understanding Text Generation Parameters in Transformers
Photo by Anton Klyuchnikov. Some rights reserved.
Overview
This post is divided into seven parts; they are:
- Core Text Generation Parameters
- Experimenting with Temperature
- Top-K and Top-P Sampling
- Controlling Repetition
- Greedy Decoding and Sampling
- Parameters for Specific Applications
- Beam Search and Multiple Sequence Generation
Core Text Generation Parameters
Let's use the GPT-2 model as an example. It is a small transformer model that does not require a lot of computational resources but can still produce good-quality text. A simple example of generating text with the GPT-2 model is as follows:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Create the model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Tokenize the input prompt into a sequence of token ids
prompt = "Artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate the output as a sequence of token ids
output = model.generate(
    **inputs,
    max_length=50,
    num_return_sequences=1,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    repetition_penalty=1.0,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)

# Convert the token ids back into a text string
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(f"Prompt: {prompt}")
print("Generated Text:")
print(generated_text)
If you run this code, you may see:
Prompt: Artificial intelligence is
Generated Text:
Artificial intelligence is used in the production of technology, the delivery of which is determined by technological change. For example, an autonomous car can change its steering wheel to help avoid driving traffic. In the case of artificial intelligence, this can change what consumers
You provided a prompt of only three words, and the model produced a long passage of text. The text was not generated in one shot; the model was invoked many times in an iterative process.
You can see that many parameters are used in the generate() function. The first one you used is max_length. As the name suggests, it controls the length of the generated text, measured in tokens. Normally, the model generates one token at a time, with the prompt as context. The newly generated token is then appended to the prompt to generate the next token. Therefore, the longer you want the generated text to be, the longer it takes to produce. Note that tokens, not words, are the unit of concern here, because you used a subword tokenizer with the GPT-2 model. A token may be only a subword unit, not a full word.
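To see concretely why the length is counted in tokens rather than words, you can inspect how the tokenizer splits a phrase. This is just a quick check; the exact splits depend on the GPT-2 vocabulary:

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# A word may be split into several subword tokens, so token count != word count
tokens = tokenizer.tokenize("Artificial intelligence is")
print(tokens)
print(f"{len('Artificial intelligence is'.split())} words -> {len(tokens)} tokens")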
However, the model does not literally produce any single token. Instead, it produces "logits", a vector whose length equals the vocabulary size and which assigns a score to every possible next token. Passing the logits through a softmax gives a probability distribution over all candidate next tokens. From this distribution, you can either pick the token with the highest probability (when you set do_sample=False) or sample any other token with non-zero probability (when you set do_sample=True). All the other parameters exist to serve this selection step.
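To make this selection step concrete, here is a minimal sketch of a single decoding step done by hand instead of through generate(). It assumes the same gpt2 model and tokenizer as above; torch.argmax mimics the greedy pick and torch.multinomial mimics sampling:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer("Artificial intelligence is", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits      # shape: (1, sequence_length, vocab_size)

next_token_logits = logits[0, -1]        # scores for the next token only
probs = torch.softmax(next_token_logits, dim=-1)

greedy_id = int(torch.argmax(probs))                       # what do_sample=False would pick
sampled_id = int(torch.multinomial(probs, num_samples=1))  # one random draw, like do_sample=True
print("greedy: ", tokenizer.decode([greedy_id]))
print("sampled:", tokenizer.decode([sampled_id]))

Running the sampling line several times gives different tokens, which is exactly what happens inside generate() at every step.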
The temperature parameter distorts the probability distribution. A lower temperature emphasizes the most likely tokens, while a higher temperature reduces the gap between the most and least likely tokens. The default temperature is 1.0, and it must be a positive value. The top_k parameter then keeps only the top $k$ tokens rather than the entire vocabulary, and the probabilities are recomputed so that they sum to 1. If top_p is also set, this set of $k$ tokens is further filtered to keep only the top tokens whose cumulative probability reaches $p$. The next token is then sampled from this final set, a process known as nucleus sampling.
Remember that you are generating a sequence of tokens, one at a time. The same tokens are likely to be favored at every step, so you may see the same token appear repeatedly in the sequence. This is usually not what you want, so you may want to lower the probability of tokens once they have already been seen. That is what the repetition_penalty parameter does.
Experimenting with Temperature
Now that you know what the various parameters do, let's see how the output changes as you adjust some of them.
The temperature parameter has a significant impact on the creativity and randomness of the generated text. You can see its effect with the following example:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt = "The future of artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate text with different temperature values
temperatures = [0.2, 0.5, 1.0, 1.5]
print(f"Prompt: {prompt}")
for temp in temperatures:
    print()
    print(f"Temperature: {temp}")
    output = model.generate(
        **inputs,
        max_length=100,
        num_return_sequences=1,
        temperature=temp,
        top_k=50,
        top_p=1.0,
        repetition_penalty=1.0,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print("Generated Text:")
    print(generated_text)
When you run this code, you may see:
Prompt: The future of artificial intelligence is

Temperature: 0.2
Generated Text:
The future of artificial intelligence is uncertain. The future of artificial intelligence is uncertain. The future of artificial intelligence is uncertain. The future of artificial intelligence is uncertain. The future of artificial intelligence is uncertain. The future of artificial intelligence is uncertain. The future of artificial intelligence is uncertain. The future of artificial intelligence is uncertain. The future of artificial intelligence is uncertain. The future of artificial intelligence is uncertain. The future of artificial intelligence is uncertain. The future

Temperature: 0.5
Generated Text:
The future of artificial intelligence is uncertain. "There is a lot of work to be done on this," said Eric Schmitt, a professor of computer science and engineering at the University of California, Berkeley. "We're looking for a way to make AI more like computers. We need to take a step back and look at how we think about it and how we interact with it." Schmitt said he's confident that artificial intelligence will eventually be able to do more than

Temperature: 1.0
Generated Text:
The future of artificial intelligence is not yet clear, however." "Is the process that we are trying to do through computer vision and the ability to look at a person at multiple points without any loss of intelligence due to not seeing a person at multiple points?" asked Richard. "I also think the people who are doing this research are extremely interesting to me due to being able to see humans at a range of different points in time. In particular, they've shown how to do a pretty complex

Temperature: 1.5
Generated Text:
The future of artificial intelligence is an era to remember as much as Google in search results, particularly ones not supported by much else for some years -- and it might look like the search giant is now just as good without artificial intelligence. [Graphic image from Shutterstock]
At a lower temperature (e.g., 0.2), the text becomes more focused and deterministic, often sticking to common phrasing and conventional ideas. You will also find that it keeps repeating the same sentence, because the probability mass is concentrated on a few tokens, limiting diversity. This can be fixed with the repetition penalty parameter, which is covered in a later section.

At a medium temperature (e.g., 0.5 to 1.0), the text strikes a good balance between coherence and creativity. The generated content may not be accurate, but the language sounds natural.

At a higher temperature (e.g., 1.5), the text becomes more random and creative, but it may also be less coherent and at times illogical. As the example above shows, the language can be difficult to follow.

Choosing the right temperature depends on your application. If you are building an assistant for code completion or writing, a lower temperature is usually better. For creative writing or brainstorming, a higher temperature can produce more varied and interesting results.
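The reshaping effect of temperature is easy to see on a toy example. The logits below are made up; the point is only how dividing by the temperature before the softmax changes the distribution:

import torch

logits = torch.tensor([4.0, 3.0, 2.0, 0.5])   # made-up scores for four candidate tokens
for temp in [0.2, 0.5, 1.0, 1.5]:
    probs = torch.softmax(logits / temp, dim=-1)
    print(f"T={temp}: {[round(p, 3) for p in probs.tolist()]}")

At T=0.2, almost all of the probability mass sits on the first token, while at T=1.5 the distribution is noticeably flatter.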
Top-K and Top-P Sampling
The nucleus sampling parameters control how much flexibility the model has when picking the next token. Should you adjust the top_k parameter or the top_p parameter? Let's see their effects with an example:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt = "The best way to learn programming is"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate text with different top_k values
top_k_values = [5, 20, 50]
print(f"Prompt: {prompt}")
for top_k in top_k_values:
    print()
    print(f"Top-K = {top_k}")
    output = model.generate(
        **inputs,
        max_length=100,
        num_return_sequences=1,
        temperature=1.0,
        top_k=top_k,
        top_p=1.0,
        repetition_penalty=1.0,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print("Generated Text:")
    print(generated_text)

# Generate text with different top_p values
top_p_values = [0.5, 0.7, 0.9]
for top_p in top_p_values:
    print()
    print(f"Top-P = {top_p}")
    output = model.generate(
        **inputs,
        max_length=100,
        num_return_sequences=1,
        temperature=1.0,
        top_k=0,
        top_p=top_p,
        repetition_penalty=1.0,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print("Generated Text:")
    print(generated_text)
When you run this code, you may see:
Prompt: The best way to learn programming is

Top-K = 5
Generated Text:
The best way to learn programming is to be able to learn the basics in a very short amount of time, and then learn to use them effectively and quickly. If you want to be a successful programmer in this way, you should learn to use the techniques in the above video to learn the basics of programming. If you want to learn to code more effectively, you can also get more experienced programmers by doing the following
Learning to Code
Learning to code is very

Top-K = 20
Generated Text:
The best way to learn programming is to learn it. In order to get started with Ruby you're going to have to make a few mistakes, some of them can be fairly obvious. First of all, you're going to have to write a function that takes in a value. What this means is that you're going to make a new instance of the Ruby function. You can read more about this in Part 1 of this course, or just try it out from the REPL.

Top-K = 50
Generated Text:
The best way to learn programming is to become familiar with the language and the software. One of the first and most common forms of programming is to create, modify, and distribute code. However, there are very few programming libraries that can provide us with all that we need. The following sample programming program uses some of the above, but does not show the best way to learn programming. It was written in Java and in C or C++. The original source code is

Top-P = 0.5
Generated Text:
The best way to learn programming is to be able to create a tool for you. That's what I do. That's why I'm here today. I'm here to talk about the basics of programming, and I'm going to tell you how to learn programming. I'm here to talk about learning programming. It's easy to forget that you don't have to know how to program. It's easy to forget that you don't have to know how

Top-P = 0.7
Generated Text:
The best way to learn programming is to practice programming. Learn the principles of programming by observing and performing exercises. I used to work in a world of knowledge which included all sorts of things, and was able to catch up on them and understand them from their perspective. For instance, I learned to sit up straight and do five squats. Then, I would have to practice some type of overhead training. I would try to learn the best technique and add that to my repertoire. What

Top-P = 0.9
Generated Text:
The best way to learn programming is to become a good hacker. Don't use any programming tools. Just a regular dot-com user, an occasional coding learner, and stick with it. — Victoria E. Nichols
You can see that with a small value of k, such as 5, the model has fewer options to choose from, resulting in more predictable text. At the extreme, with k=1, the model always picks the single most probable token, which is greedy decoding and often produces poor output. With a larger k, such as 50, the model has more options to choose from, resulting in more diverse text.

Similarly, for the top_p parameter, a smaller p means the model selects from a smaller set of high-probability tokens, producing more focused text. With a larger p, such as 0.9, the selection pool is wider, which can lead to more diverse text. However, for a given p, the number of candidate tokens is not fixed; it depends on the probability distribution the model predicts. When the model is very confident about the next token (for example, when it is constrained by grammatical rules), only a very small set of tokens survives the filter. This adaptive behavior is also why nucleus sampling is often preferred over top-k sampling.
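Conceptually, both filters amount to a few lines of tensor code. The sketch below applies them to a made-up, already-sorted probability vector; the actual logits processors inside transformers are more careful (they operate on logits and handle batching), so treat this only as an illustration:

import torch

probs = torch.tensor([0.40, 0.25, 0.15, 0.10, 0.06, 0.04])  # made-up, sorted probabilities

# Top-k: keep the k most likely tokens, then renormalize
k = 3
top_k_probs = probs[:k] / probs[:k].sum()
print("top-k keeps", k, "tokens:", top_k_probs.tolist())

# Top-p (nucleus): keep the smallest prefix whose cumulative probability reaches p
p = 0.7
keep = int((torch.cumsum(probs, dim=-1) < p).sum()) + 1
top_p_probs = probs[:keep] / probs[:keep].sum()
print("top-p keeps", keep, "tokens:", top_p_probs.tolist())

Note that top-k always keeps exactly k tokens, while the number kept by top-p changes with the shape of the distribution, which is the adaptive behavior described above.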
Controlling Repetition
Repetition is a common problem in text generation. The repetition_penalty parameter helps address it by penalizing tokens that have already appeared in the generated text. Let's see how it works:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt = "Once upon a time, there was a"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate text with different repetition penalty values
penalties = [1.0, 1.2, 1.5, 2.0]
print(f"Prompt: {prompt}")
for penalty in penalties:
    print()
    print(f"Repetition penalty: {penalty}")
    output = model.generate(
        **inputs,
        max_length=100,
        num_return_sequences=1,
        temperature=0.3,
        top_k=50,
        top_p=1.0,
        repetition_penalty=penalty,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print("Generated Text:")
    print(generated_text)
When you run this code, you may see:
Prompt: Once upon a time, there was a

Repetition penalty: 1.0
Generated Text:
Once upon a time, there was a great deal of confusion about what was going on. The first thing that came to mind was the fact that the government had already been in place for a long time, and that the government had been in place for a long time. And it was clear that the government had been in place for a long time. And it was clear that the government had been in place for a long time. And it was clear that the government had been in place for a long

Repetition penalty: 1.2
Generated Text:
Once upon a time, there was a great deal of talk about the possibility that this would be an opportunity for us to see more and better things in our lives. We had been talking on Facebook all day long with people who were interested in what we could do next or how they might help others find their own way out." "We've always wanted to make sure everyone has access," he continued; "but it's not like you can just go into your room at night looking around without seeing

Repetition penalty: 1.5
Generated Text:
Once upon a time, there was a man who had been called to the service of God. He came and said: "I am an apostle from Jerusalem." And he answered him with great joy, saying that it is not possible for me now in this life without having received Jesus Christ as our Lord; but I will be saved through Him alone because my Father has sent Me into all things by His Holy Spirit (John 1). The Christian Church teaches us how much more than any other religion can

Repetition penalty: 2.0
Generated Text:
Once upon a time, there was a man who had been sent to the city of Nausicaa by his father. The king's son and brother were killed in battle at that place; but when he returned with them they found him dead on their way back from war-time.[1] The King gave orders for an expedition against this strange creature called "the Gorgon," which came out into space during one night after it attacked Earth[2]. It is said that these creatures
In the code above, the temperature is set to 0.3 to emphasize the effect of the repetition penalty. With a low penalty of 1.0, you can see the model repeating the same phrase over and over. The model can easily get stuck in a loop when the other settings restrict the candidate tokens to a small subset. But with a high penalty, such as 2.0 or above, the model strongly avoids repetition, which can sometimes make the text less natural. A moderate penalty (e.g., 1.2 to 1.5) is often a good compromise for maintaining coherence.
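For reference, the penalty works by rescaling the scores of tokens that have already appeared, following the formulation from the CTRL paper: a positive score is divided by the penalty and a negative score is multiplied by it, so the token becomes less attractive either way. Here is a rough sketch of that idea on a made-up score vector (not the library's actual code):

import torch

scores = torch.tensor([2.0, -1.0, 0.5, 3.0])  # made-up next-token scores
already_seen = [0, 3]                          # token ids that already appeared
penalty = 1.5

for token_id in already_seen:
    if scores[token_id] > 0:
        scores[token_id] /= penalty    # positive score shrinks
    else:
        scores[token_id] *= penalty    # negative score becomes more negative
print(scores.tolist())                 # tokens 0 and 3 are now less likely to be picked again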
After all, the point of setting the parameters in the generate() function is to make the text fluent and natural. You may need to experiment with these parameters to see which combination works best for your particular application. Note that the best values may also depend on the model you are using, since each model may produce a different distribution over tokens.
Greedy Decoding and Sampling
The do_sample parameter controls whether the model uses sampling (picking tokens probabilistically) or greedy decoding (always picking the most likely token). Let's compare the two approaches:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt = "The secret to happiness is"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate text with greedy decoding vs. sampling
print(f"Prompt: {prompt}\n")

print("Greedy Decoding (do_sample=False):")
output = model.generate(
    **inputs,
    max_length=100,
    num_return_sequences=1,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    repetition_penalty=1.0,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated Text:")
print(generated_text)

print()
print("Sampling (do_sample=True):")
output = model.generate(
    **inputs,
    max_length=100,
    num_return_sequences=1,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    repetition_penalty=1.0,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated Text:")
print(generated_text)
Try running this code multiple times and observing the output. You will notice that the greedy decoding output is always the same, while the sampled output differs from run to run. For a fixed prompt, greedy decoding is deterministic: the model produces a probability distribution and the most likely token is picked, with no randomness involved. The output is more prone to being repetitive and unhelpful.

The sampled output is stochastic because each output token is drawn from the probability distribution the model predicts. The randomness lets the model produce more diverse and creative text, and the output remains coherent as long as the other generation parameters are set appropriately. With sampling, you can also set num_return_sequences to a number greater than 1 to generate multiple sequences in parallel for the same prompt. This parameter is meaningless for greedy decoding.
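If you want a sampled run to be repeatable (for debugging or testing), you can fix the random seed before calling generate(). A small self-contained sketch using transformers.set_seed, which seeds Python, NumPy, and PyTorch in one call:

from transformers import GPT2LMHeadModel, GPT2Tokenizer, set_seed

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
inputs = tokenizer("The secret to happiness is", return_tensors="pt")

set_seed(42)   # same seed + same parameters -> same sampled text on every run
output = model.generate(**inputs, max_length=50, do_sample=True,
                        pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(output[0], skip_special_tokens=True))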
Parameters for Specific Applications
What parameter values should you set for a particular application? There is no exact answer; you will certainly need to run some experiments to find the best combination. But you can use the following as a starting point (a short code sketch at the end of this section shows one way to package these values for reuse):
- Factual generation:
  - Lower temperature (0.2 to 0.4) for more deterministic output
  - Moderate top_p (0.8 to 0.9) to filter out unlikely tokens
  - Higher repetition_penalty (1.2 to 1.5) to avoid repeated statements
- Creative writing:
  - Higher temperature (1.0 to 1.3) for more creative and diverse output
  - Higher top_p (0.9 to 0.95) to allow more possibilities
  - Lower repetition_penalty (1.0 to 1.1) to allow some stylistic repetition
- Code generation:
  - Lower temperature (0.1 to 0.3) for more precise and correct code
  - Lower top_p (0.7 to 0.8) to focus on the most likely tokens
  - Higher repetition_penalty (1.3 to 1.5) to avoid redundant code
- Conversational responses:
  - Moderate temperature (0.6 to 0.8) for natural but focused responses
  - Moderate top_p (0.9) for a good balance between creativity and coherence
  - Moderate repetition_penalty (1.2) to avoid repeated phrases
Remember that a language model is not a perfect oracle; it can make mistakes. The parameters above help you fit the generation process to the intended output style, but they do not guarantee correctness. The output you get may still contain errors.
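One convenient way to work with these starting points is to keep them as plain dictionaries of keyword arguments and unpack them into generate(). The preset names, the prompt, and the exact numbers below are just the suggestions above packaged for reuse, not an official API:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
inputs = tokenizer("The Eiffel Tower is located in", return_tensors="pt")

# Starting-point values from the list above, packaged as reusable presets
GENERATION_PRESETS = {
    "factual":      {"temperature": 0.3, "top_p": 0.85, "repetition_penalty": 1.3},
    "creative":     {"temperature": 1.2, "top_p": 0.95, "repetition_penalty": 1.05},
    "code":         {"temperature": 0.2, "top_p": 0.75, "repetition_penalty": 1.4},
    "conversation": {"temperature": 0.7, "top_p": 0.9,  "repetition_penalty": 1.2},
}

output = model.generate(
    **inputs,
    max_length=50,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
    **GENERATION_PRESETS["factual"],   # switch presets to change the output style
)
print(tokenizer.decode(output[0], skip_special_tokens=True))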
Beam Search and Multiple Sequence Generation
In the examples above, generation is autoregressive: it is an iterative process that produces one token at a time.

Since each step produces one token by sampling, nothing stops you from producing several candidate tokens at each step. If you do, you will generate multiple output sequences for a single input prompt. In theory, if you produce $k$ tokens at each step and the target length is $n$, you end up with $k^n$ sequences. This is a huge number, and you will probably want to limit it to just a few.
The first way to generate multiple sequences is to set num_return_sequences to a number $k$. You generate $k$ tokens at the first step and then complete a sequence for each of them. This effectively duplicates the prompt $k$ times in the generation process.
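A minimal sketch of this first approach: with do_sample=True, setting num_return_sequences to $k$ simply returns $k$ independently sampled completions of the same prompt.

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
inputs = tokenizer("The key to successful machine learning is", return_tensors="pt")

outputs = model.generate(
    **inputs,
    max_length=40,
    do_sample=True,           # sampling is required to get different sequences this way
    num_return_sequences=3,   # three independent completions of the same prompt
    pad_token_id=tokenizer.eos_token_id,
)
for idx, seq in enumerate(outputs):
    print(f"[{idx + 1}]", tokenizer.decode(seq, skip_special_tokens=True))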
The second way is to use beam search. It is a more sophisticated method for generating multiple sequences: it keeps track of the most promising sequences and explores them in parallel. Instead of generating $k^n$ sequences and overwhelming the memory, it keeps only the $k$ best sequences at each step. Each token-generation step temporarily expands this set and then prunes it back to the $k$ best sequences.
To use beam search, set num_beams to a number $k$. Each step expands each of the $k$ sequences by one token, producing $k^2$ candidates, and then the best $k$ are selected to continue to the next step. You can also set early_stopping=True so that generation stops once the sequences reach the end. You should also set num_return_sequences to limit how many sequences appear in the final output.
Sequences are usually selected based on the cumulative probability of the tokens in the sequence, but you can also influence the selection with other criteria, such as adding a length penalty or forbidding repeated n-grams. Here is an example of using beam search:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt = "The key to successful machine learning is"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate multiple sequences with beam search
print(f"Prompt: {prompt}\n")
outputs = model.generate(
    **inputs,
    num_beams=5,              # number of beams to use
    early_stopping=True,      # stop when all beams are finished
    no_repeat_ngram_size=2,   # avoid repeated n-grams
    num_return_sequences=3,   # return multiple sequences
    max_length=100,
    temperature=1.5,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)
for idx, output in enumerate(outputs):
    generated_text = tokenizer.decode(output, skip_special_tokens=True)
    print(f"Generated Text ({idx+1}):")
    print(generated_text)
You can add more generation parameters (such as length_penalty) to control the generation process. The example above sets a higher temperature to highlight the output of beam search. If you run this code, you may see:
Prompt: The key to successful machine learning is

Generated Text (1):
The key to successful machine learning is to be able to learn from the world around you. It is our job to make sure that we are learning from people, rather than just from machines. So, let's take a step back and look at how we can learn. Here's a list of the tools we use to help us do that. We're going to go over a few of them here and give you a general idea of what they are and how you can use them to create

Generated Text (2):
The key to successful machine learning is to be able to learn from the world around you. It is our job to make sure that we are learning from people, rather than just from machines. So, let's take a step back and look at how we can learn. Here's a list of the tools we use to help us do that. We're going to go over a few of them here and give you a general idea of what they are and how you can use them and what

Generated Text (3):
The key to successful machine learning is to be able to learn from the world around you. It is our job to make sure that we are learning from people, rather than just from machines. So, let's take a step back and look at how we can learn. Here's a list of the tools we use to help us do that. We're going to go over a few of them here and give you a general idea of what they are and how they work. You can use
The number of output sequences is still controlled by num_return_sequences, but the process that generates them uses the beam search algorithm. It is not easy to tell from the output alone whether beam search was used. One sign is that the beam search output is not as diverse as what you get from setting num_return_sequences alone, because many more sequences are generated and the ones with higher cumulative probability are selected; this filtering does reduce the diversity of the output.
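If you want to see how the surviving beams were ranked, generate() can also return the cumulative score of each returned sequence. This sketch assumes a reasonably recent version of transformers, where return_dict_in_generate=True together with output_scores=True exposes a sequences_scores field for beam search:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
inputs = tokenizer("The key to successful machine learning is", return_tensors="pt")

result = model.generate(
    **inputs,
    num_beams=5,
    num_return_sequences=3,
    max_length=50,
    early_stopping=True,
    no_repeat_ngram_size=2,
    return_dict_in_generate=True,   # return a structured output instead of a bare tensor
    output_scores=True,             # keep the scores accumulated during beam search
    pad_token_id=tokenizer.eos_token_id,
)
# sequences_scores holds the final (length-normalized) log-probability of each returned beam
for seq, score in zip(result.sequences, result.sequences_scores):
    print(f"{score.item():.3f}", tokenizer.decode(seq, skip_special_tokens=True))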
Further Reading
Below are some resources that you may find useful:
Summary
In this post, you saw how the many parameters of the generate() function can be used to control the generation process. You can adjust these parameters to make the output match the style you expect for your application. Specifically, you learned:

- How to use temperature to control the probability distribution of the output
- How to use top-k and top-p to control the diversity of the output
- How to use the repetition penalty, beam search, and greedy decoding to control the output

By understanding and tuning these parameters, you can optimize text generation for different applications, from factual writing to creative storytelling, code generation, and dialogue systems.