语言模型需要理解序列中词语之间的关系,无论它们之间的距离如何。本文探讨了注意力机制如何实现这一能力,以及它们在现代语言模型中的各种实现方式。
让我们开始吧。

多头注意力(Multi-Head Attention)和分组查询注意力(Grouped-Query Attention)的初步介绍
图片作者:Ye Min Htet。保留部分权利。
概述
这篇博文分为三部分;它们是:
- 为何需要注意力机制
- 注意力操作
- 多头注意力(MHA)
- 分组查询注意力(GQA)和多查询注意力(MQA)
为何需要注意力机制
传统神经网络在处理序列中的长距离依赖关系时遇到困难。考虑一下这句话:
“The animal didn’t cross the road because it was too tired.”(动物没有过马路,因为它太累了。)
为了理解“it”指的是什么,模型需要回顾“animal”,这是一个跨越多个词语的关系。
另一个例子是翻译句子“I want to try on a suit that I saw in a shop that’s across the street from the hotel”(我想试穿我在酒店对面商店里看到的一套西装)。这是一个著名的句子,它展示了不同语言之间的差异,如下图所示:
如果你将这个句子从英语翻译成法语,你可能可以按照原始顺序逐一匹配单词。然而,这并非一帆风顺,因为英语中的“want”在法语中可能是“veux”、“voulons”或“voulez”,具体取决于它与“I”、“we”或“you”的关联。因此,翻译模型需要**关注**“Je”(法语中“I”的对应词,以确定动词形式)和“want”(以确定动词),才能找到正确的翻译。
翻译成日语甚至更具挑战性,因为日语使用主宾谓(SOV)语序。当模型看到“I want…”时,它需要等到句末才能确定宾语。在模型生成“私は”(日语中“I”的对应词)和“ホテル”(日语中“hotel”的对应词)之后,紧接的词并不会影响翻译,但它必须与英语句子的最后一个词相匹配。
注意力机制的目的是帮助模型聚焦于序列中的相关部分,同时忽略其余部分。
注意力操作
注意力机制是为了解决翻译模型中长距离依赖问题而发明的。在翻译语境中理解它最容易。
让我们考虑已经处理完上一节中的英语句子。现在模型正在逐个生成法语单词。在第一个单词“Je”之后,模型需要决定第二个单词应该是什么。
首先,我们将目前生成的法语单词定义为“查询”(query)序列,将已处理的英语句子定义为“键”(key)序列。注意力操作首先计算注意力分数:
$$
\frac{QK^T}{\sqrt{d}}
$$
这是一个矩阵,其中元素 $(i,j)$ 是第 $i$ 个法语单词和第 $j$ 个英语单词之间对齐的分数。分数越高,这两个单词**对齐**得越紧密。对齐并不意味着等价,而是指示下一步应该关注什么。在这个例子中,对齐应该关注英语句子中的“want”,它应该具有最高的注意力分数。上述公式中的 $\sqrt{d}$ 部分是一个常数,用于缩放注意力分数。你现在可以忽略它的作用。
之后,注意力分数通过 softmax 函数进行归一化:
$$
\text{softmax}\Big(\frac{QK^T}{\sqrt{d}}\Big)
$$
Softmax 函数将注意力分数归一化,使得每一行之和为1。这样做的原因在下一步变得清晰,即计算值的加权和:
$$
O = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V
$$
$V$ 是“值”(value)序列。在这个例子中,你可以将其视为原始句子中每个英语单词的法语翻译。这只是一个简单的翻译,没有考虑任何上下文。在“want”的情况下,它可能是“vouloir”,而不是“veux”。
上面的方程计算了值的加权和。因为我们希望输出 $O$ 是一个词,所以权重之和应该为 1,以防止输出变得过大或过小。为什么要使用加权和而不是选择单个词的翻译呢?因为法语中的一个词可能与英语句子中的多个词相关。在这个例子中,你可能需要 90% 的“want”(“vouloir”)和 10% 的“I”(“je”)来生成正确的动词形式“veux”作为法语的下一个输出词。
上述解释中缺少的一点是,英语句子中的词如何变成法语词,以便你可以在“值”序列中使用它们。这实际上是由一个“投影矩阵”生成的。完整的方程是:
$$
\begin{aligned}
Q &= F W^Q \\
K &= E W^K \\
V &= E W^V \\
X_O &= \Big(\text{softmax}\big(\frac{QK^T}{\sqrt{d}}\big)V\Big)W^O
\end{aligned}
$$
其中 $E$ 和 $F$ 分别是英语单词序列和部分法语单词序列。 $X_O$ 是输出,即法语的下一个单词。 $W^Q$、$W^K$ 和 $W^V$ 是将序列转换到不同空间的投影矩阵。 $W^O$ 是将输出 $O$ 转换回原始空间 $X_O$ 的投影矩阵。
多头注意力(MHA)
上一节的描述只是注意力操作的高层视图。在翻译模型中,输入到注意力的序列实际上不是单词序列,而是词嵌入向量序列。投影矩阵将嵌入向量转换到不同的空间,然后在这个转换后的空间中应用注意力操作。
我们如何确定投影矩阵?这实际上很困难。原因是每个词都可以被转换成多个不同的空间。例如,存在一个“意义空间”来表示词的意义。也可能存在一个词性空间来指示一个词是名词、动词还是形容词。你不需要只选择一个。没有什么可以阻止你并行使用多个空间。
这就是“多头注意力”(MHA)被引入的原因。你使用的不是一组,而是多组投影矩阵,每组都执行自己的注意力操作。然后将输出连接起来以产生最终输出。
因为你有多个相互独立的注意力头,所以你可以并行运行它们。原始的 Transformer 架构使用 8 个注意力头,并被发现在翻译任务中表现良好。
以方程形式,MHA 可以表示为:
$$
\begin{aligned}
\text{Attention}(Q, K, V) &= \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V \\
\text{head}_i &= \text{Attention}(X_QW^Q_i, X_KW^K_i, X_VW^V_i) \\
\text{MultiHead}(X_Q, X_K, X_V) &= \text{Concat}(\text{head}_1, …, \text{head}_h)W^O
\end{aligned}
$$
在 PyTorch 中,你可以通过实现上述方程来创建自己的 MHA 层:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import torch import torch.nn as nn import torch.nn.functional as F import math class MultiHeadAttention(nn.Module): def __init__(self, d_model, num_heads): super().__init__() self.d_model = d_model self.num_heads = num_heads self.head_dim = d_model // num_heads self.q_proj = nn.Linear(d_model, d_model) self.k_proj = nn.Linear(d_model, d_model) self.v_proj = nn.Linear(d_model, d_model) self.out_proj = nn.Linear(d_model, d_model) def forward(self, x): batch_size = x.size(0) seq_length = x.size(1) # Project queries, keys, and values q = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) # Compute attention scores, optionally add attention mask to the score scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim) attn_weights = F.softmax(scores, dim=-1) # optional: attn_weights = F.dropout(attn_weights, p=0.2) # Apply attention weights to values context = torch.matmul(attn_weights, v).transpose(1, 2).contiguous() context = context.view(batch_size, seq_length, self.d_model) return self.out_proj(context) |
注意力机制的一个显著特点是,我们通常保持输入和输出序列中的向量维度相同。对于简单的注意力操作来说,这通常不是问题。但对于 MHA,存在多个注意力头,并且输出将沿着向量维度进行拼接。因此,每个头应该在减小的维度 head_dim = d_model // num_heads
中操作,以便进行拼接。
在上面的构造函数中,输入投影矩阵被定义为 q_proj
、k_proj
和 v_proj
。输出投影矩阵被定义为 out_proj
。
在 forward()
函数中,输入 x
通过输入投影矩阵投影到 q
、k
和 v
。输入 x
的形状为 (batch_size, seq_length, d_model)
。投影后形状保持不变,但随后被重塑并转置为 (batch_size, num_heads, seq_length, head_dim)
。计算 scores
的 matmul()
对齐 head_dim
维度(最后一个轴),结果注意力分数的形状为 (batch_size, num_heads, seq_length, seq_length)
。然后沿最后一个轴应用 softmax,使该轴上的和为 1。
如果您需要对注意力机制应用**掩码**,例如解码器专用模型中常用的**因果掩码**,您应该在应用 softmax 之前将其应用于 scores
张量。有些模型在 softmax 之后对注意力权重应用**dropout**。据信这有助于模型变得更健壮。
注意力权重随后与 `v` 相乘,结果被转置回形状 `(batch_size, seq_length, num_heads, head_dim)`。`contiguous()` 用于使结果在内存中连续,以便每个头的向量可以使用 `view()` 拼接回原始形状 `(batch_size, seq_length, d_model)`。然后这个张量由 `out_proj` 投影,并作为注意力操作的输出。
上述实现是**自注意力**,因为类中的 `forward()` 函数使用相同的输入 `x` 来创建 `q`、`k` 和 `v`。在**交叉注意力**中,一个输入序列用于 `q`,另一个输入序列用于 `k` 和 `v`。
请注意,在 PyTorch 中,`view()` 是更改张量形状比 `reshape()` 更快的方式,但它要求张量形状是连续的。如果您对某些轴进行了转置,张量将不连续。您应该在张量上调用 `contiguous()` 来移动内存,使其再次连续。在创建 `context` 张量的行中就是这种情况。
实际上,在 PyTorch 中,上述类已经作为 `torch.nn.MultiheadAttention` 实现。您应该使用它。
分组查询注意力(GQA)和多查询注意力(MQA)
多头注意力(MHA)是最强大的注意力机制,但它涉及大量的计算。有多种方法可以降低计算成本。分组查询注意力(GQA)是最受欢迎的一种。
GQA 通过在查询头**组**之间共享键和值投影来降低 MHA 的计算成本:
$$
\begin{aligned}
\text{head}_i &= \text{Attention}(X_QW^Q_i, X_KW^K_{g(i)}, X_VW^V_{g(i)}) \\
\text{GQA}(X_Q, X_K, X_V) &= \text{Concat}(\text{head}_1, …, \text{head}_h)W^O
\end{aligned}
$$
与MHA相比,GQA对同一组 $g(i)$ 中的所有查询头使用相同的投影矩阵 $W^K_{g(i)}$ 和 $W^V_{g(i)}$。常见的分组方式是均匀分割头部,例如:
$$
\begin{aligned}
g(i) &= \left\lfloor \frac{i}{m} \right\rfloor \\
\therefore\; 0 &= g(0) = g(1) = \cdots = g(m-1) \\
1 &= g(m) = g(m+1) = \cdots = g(2m-1) \\
\vdots \\
\end{aligned}
$$
将上述代码示例修改为 GQA 很容易:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import torch import torch.nn as nn import torch.nn.functional as F import math class GroupedQueryAttention(nn.Module): def __init__(self, d_model, num_heads, num_groups): super().__init__() self.d_model = d_model self.num_heads = num_heads # 查询头的数量 self.num_groups = num_groups self.group_size = num_heads // num_groups self.head_dim = d_model // num_heads self.q_proj = nn.Linear(d_model, self.num_heads * self.head_dim) self.k_proj = nn.Linear(d_model, self.num_groups * self.head_dim) self.v_proj = nn.Linear(d_model, self.num_groups * self.head_dim) self.out_proj = nn.Linear(self.num_heads * self.head_dim, d_model) def forward(self, x): batch_size = x.size(0) seq_length = x.size(1) # Project queries, keys, and values q = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(batch_size, seq_length, self.num_groups, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(batch_size, seq_length, self.num_groups, self.head_dim).transpose(1, 2) # Expand k and v to match the number of query heads k = k.repeat_interleave(self.group_size, dim=1) v = v.repeat_interleave(self.group_size, dim=1) # Compute attention scores scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim) attn_weights = F.softmax(scores, dim=-1) # optional: attn_weights = F.dropout(attn_weights, p=0.2) # Apply attention weights to values context = torch.matmul(attn_weights, v).transpose(1, 2).contiguous() context = context.view(batch_size, seq_length, self.d_model) return self.out_proj(context) |
与 MHA 相比,投影矩阵 k_proj
和 v_proj
不同。具体来说,这些投影矩阵更小,因此矩阵乘法计算速度更快。
由于投影矩阵 k_proj
和 q_proj
的形状不同,因此 q
和 k
之间的乘法是不可能的。因此,您需要使用 repeat_interleave()
将 k
扩展到与 q
相同的形状。这仅在 num_heads
可被 num_groups
整除时才有效。出于同样的原因,您还需要使用 repeat_interleave()
扩展 v
。
或者,您可以使用 PyTorch 内置的 `scaled_dot_product_attention()` 函数来简化上述实现:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import torch import torch.nn as nn import torch.nn.functional as F import math class GroupedQueryAttention(nn.Module): def __init__(self, d_model, num_heads, num_groups): super().__init__() self.d_model = d_model self.num_heads = num_heads # 查询头的数量 self.num_groups = num_groups self.group_size = num_heads // num_groups self.head_dim = d_model // num_heads self.q_proj = nn.Linear(d_model, self.num_heads * self.head_dim) self.k_proj = nn.Linear(d_model, self.num_groups * self.head_dim) self.v_proj = nn.Linear(d_model, self.num_groups * self.head_dim) self.out_proj = nn.Linear(self.num_heads * self.head_dim, d_model) def forward(self, x): batch_size = x.size(0) seq_length = x.size(1) # Project queries, keys, and values q = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(batch_size, seq_length, self.num_groups, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(batch_size, seq_length, self.num_groups, self.head_dim).transpose(1, 2) # 使用 PyTorch 内置函数计算注意力分数 attn_output = F.scaled_dot_product_attention(q, k, v, enable_gqa=True) # 输出投影 context = attn_output.transpose(1, 2).contiguous() context = context.view(batch_size, seq_length, self.d_model) return self.out_proj(context) |
请注意,调用 `scaled_dot_product_attention()` 并设置 `enable_gqa=True` 就可以替换对 `repeat_interleave()`、`matmul()` 和 `softmax()` 的调用。
研究发现,GQA 可以在对模型质量影响最小的情况下,降低内存使用和计算时间。
如果在 GQA 中将分组数设置为 1,它就变成了**多查询注意力**(MQA)。但如果将分组数设置为与查询头数相同,它就会退化为多头注意力。
进一步阅读
以下是一些与该主题相关的论文:
总结
在这篇文章中,您了解了语言模型中的注意力机制。特别是,您了解了:
- 为什么注意力对于捕捉序列中的关系至关重要
- 多头注意力如何实现不同类型的关系
- 分组查询注意力如何平衡效率和性能
暂无评论。