多头注意力(Multi-Head Attention)和分组查询注意力(Grouped-Query Attention)的初步介绍

语言模型需要理解序列中词语之间的关系,无论它们之间的距离如何。本文探讨了注意力机制如何实现这一能力,以及它们在现代语言模型中的各种实现方式。

让我们开始吧。

多头注意力(Multi-Head Attention)和分组查询注意力(Grouped-Query Attention)的初步介绍
图片作者:Ye Min Htet。保留部分权利。

概述

这篇博文分为三部分;它们是:

  • 为何需要注意力机制
  • 注意力操作
  • 多头注意力(MHA)
  • 分组查询注意力(GQA)和多查询注意力(MQA)

为何需要注意力机制

传统神经网络在处理序列中的长距离依赖关系时遇到困难。考虑一下这句话:

“The animal didn’t cross the road because it was too tired.”(动物没有过马路,因为它太累了。)

为了理解“it”指的是什么,模型需要回顾“animal”,这是一个跨越多个词语的关系。

另一个例子是翻译句子“I want to try on a suit that I saw in a shop that’s across the street from the hotel”(我想试穿我在酒店对面商店里看到的一套西装)。这是一个著名的句子,它展示了不同语言之间的差异,如下图所示:

如果你将这个句子从英语翻译成法语,你可能可以按照原始顺序逐一匹配单词。然而,这并非一帆风顺,因为英语中的“want”在法语中可能是“veux”、“voulons”或“voulez”,具体取决于它与“I”、“we”或“you”的关联。因此,翻译模型需要**关注**“Je”(法语中“I”的对应词,以确定动词形式)和“want”(以确定动词),才能找到正确的翻译。

翻译成日语甚至更具挑战性,因为日语使用主宾谓(SOV)语序。当模型看到“I want…”时,它需要等到句末才能确定宾语。在模型生成“私は”(日语中“I”的对应词)和“ホテル”(日语中“hotel”的对应词)之后,紧接的词并不会影响翻译,但它必须与英语句子的最后一个词相匹配。

注意力机制的目的是帮助模型聚焦于序列中的相关部分,同时忽略其余部分。

注意力操作

注意力机制是为了解决翻译模型中长距离依赖问题而发明的。在翻译语境中理解它最容易。

让我们考虑已经处理完上一节中的英语句子。现在模型正在逐个生成法语单词。在第一个单词“Je”之后,模型需要决定第二个单词应该是什么。

首先,我们将目前生成的法语单词定义为“查询”(query)序列,将已处理的英语句子定义为“键”(key)序列。注意力操作首先计算注意力分数:

$$
\frac{QK^T}{\sqrt{d}}
$$

这是一个矩阵,其中元素 $(i,j)$ 是第 $i$ 个法语单词和第 $j$ 个英语单词之间对齐的分数。分数越高,这两个单词**对齐**得越紧密。对齐并不意味着等价,而是指示下一步应该关注什么。在这个例子中,对齐应该关注英语句子中的“want”,它应该具有最高的注意力分数。上述公式中的 $\sqrt{d}$ 部分是一个常数,用于缩放注意力分数。你现在可以忽略它的作用。

之后,注意力分数通过 softmax 函数进行归一化:

$$
\text{softmax}\Big(\frac{QK^T}{\sqrt{d}}\Big)
$$

Softmax 函数将注意力分数归一化,使得每一行之和为1。这样做的原因在下一步变得清晰,即计算值的加权和:

$$
O = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V
$$

$V$ 是“值”(value)序列。在这个例子中,你可以将其视为原始句子中每个英语单词的法语翻译。这只是一个简单的翻译,没有考虑任何上下文。在“want”的情况下,它可能是“vouloir”,而不是“veux”。

上面的方程计算了值的加权和。因为我们希望输出 $O$ 是一个词,所以权重之和应该为 1,以防止输出变得过大或过小。为什么要使用加权和而不是选择单个词的翻译呢?因为法语中的一个词可能与英语句子中的多个词相关。在这个例子中,你可能需要 90% 的“want”(“vouloir”)和 10% 的“I”(“je”)来生成正确的动词形式“veux”作为法语的下一个输出词。

上述解释中缺少的一点是,英语句子中的词如何变成法语词,以便你可以在“值”序列中使用它们。这实际上是由一个“投影矩阵”生成的。完整的方程是:

$$
\begin{aligned}
Q &= F W^Q \\
K &= E W^K \\
V &= E W^V \\
X_O &= \Big(\text{softmax}\big(\frac{QK^T}{\sqrt{d}}\big)V\Big)W^O
\end{aligned}
$$

其中 $E$ 和 $F$ 分别是英语单词序列和部分法语单词序列。 $X_O$ 是输出,即法语的下一个单词。 $W^Q$、$W^K$ 和 $W^V$ 是将序列转换到不同空间的投影矩阵。 $W^O$ 是将输出 $O$ 转换回原始空间 $X_O$ 的投影矩阵。

多头注意力(MHA)

上一节的描述只是注意力操作的高层视图。在翻译模型中,输入到注意力的序列实际上不是单词序列,而是词嵌入向量序列。投影矩阵将嵌入向量转换到不同的空间,然后在这个转换后的空间中应用注意力操作。

我们如何确定投影矩阵?这实际上很困难。原因是每个词都可以被转换成多个不同的空间。例如,存在一个“意义空间”来表示词的意义。也可能存在一个词性空间来指示一个词是名词、动词还是形容词。你不需要只选择一个。没有什么可以阻止你并行使用多个空间。

这就是“多头注意力”(MHA)被引入的原因。你使用的不是一组,而是多组投影矩阵,每组都执行自己的注意力操作。然后将输出连接起来以产生最终输出。

因为你有多个相互独立的注意力头,所以你可以并行运行它们。原始的 Transformer 架构使用 8 个注意力头,并被发现在翻译任务中表现良好。

以方程形式,MHA 可以表示为:

$$
\begin{aligned}
\text{Attention}(Q, K, V) &= \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V \\
\text{head}_i &= \text{Attention}(X_QW^Q_i, X_KW^K_i, X_VW^V_i) \\
\text{MultiHead}(X_Q, X_K, X_V) &= \text{Concat}(\text{head}_1, …, \text{head}_h)W^O
\end{aligned}
$$

在 PyTorch 中,你可以通过实现上述方程来创建自己的 MHA 层:

注意力机制的一个显著特点是,我们通常保持输入和输出序列中的向量维度相同。对于简单的注意力操作来说,这通常不是问题。但对于 MHA,存在多个注意力头,并且输出将沿着向量维度进行拼接。因此,每个头应该在减小的维度 head_dim = d_model // num_heads 中操作,以便进行拼接。

在上面的构造函数中,输入投影矩阵被定义为 q_projk_projv_proj。输出投影矩阵被定义为 out_proj

forward() 函数中,输入 x 通过输入投影矩阵投影到 qkv。输入 x 的形状为 (batch_size, seq_length, d_model)。投影后形状保持不变,但随后被重塑并转置为 (batch_size, num_heads, seq_length, head_dim)。计算 scoresmatmul() 对齐 head_dim 维度(最后一个轴),结果注意力分数的形状为 (batch_size, num_heads, seq_length, seq_length)。然后沿最后一个轴应用 softmax,使该轴上的和为 1。

如果您需要对注意力机制应用**掩码**,例如解码器专用模型中常用的**因果掩码**,您应该在应用 softmax 之前将其应用于 scores 张量。有些模型在 softmax 之后对注意力权重应用**dropout**。据信这有助于模型变得更健壮。

注意力权重随后与 `v` 相乘,结果被转置回形状 `(batch_size, seq_length, num_heads, head_dim)`。`contiguous()` 用于使结果在内存中连续,以便每个头的向量可以使用 `view()` 拼接回原始形状 `(batch_size, seq_length, d_model)`。然后这个张量由 `out_proj` 投影,并作为注意力操作的输出。

上述实现是**自注意力**,因为类中的 `forward()` 函数使用相同的输入 `x` 来创建 `q`、`k` 和 `v`。在**交叉注意力**中,一个输入序列用于 `q`,另一个输入序列用于 `k` 和 `v`。

请注意,在 PyTorch 中,`view()` 是更改张量形状比 `reshape()` 更快的方式,但它要求张量形状是连续的。如果您对某些轴进行了转置,张量将不连续。您应该在张量上调用 `contiguous()` 来移动内存,使其再次连续。在创建 `context` 张量的行中就是这种情况。

实际上,在 PyTorch 中,上述类已经作为 `torch.nn.MultiheadAttention` 实现。您应该使用它。

分组查询注意力(GQA)和多查询注意力(MQA)

多头注意力(MHA)是最强大的注意力机制,但它涉及大量的计算。有多种方法可以降低计算成本。分组查询注意力(GQA)是最受欢迎的一种。

GQA 通过在查询头**组**之间共享键和值投影来降低 MHA 的计算成本:

$$
\begin{aligned}
\text{head}_i &= \text{Attention}(X_QW^Q_i, X_KW^K_{g(i)}, X_VW^V_{g(i)}) \\
\text{GQA}(X_Q, X_K, X_V) &= \text{Concat}(\text{head}_1, …, \text{head}_h)W^O
\end{aligned}
$$

与MHA相比,GQA对同一组 $g(i)$ 中的所有查询头使用相同的投影矩阵 $W^K_{g(i)}$ 和 $W^V_{g(i)}$。常见的分组方式是均匀分割头部,例如:

$$
\begin{aligned}
g(i) &= \left\lfloor \frac{i}{m} \right\rfloor \\
\therefore\; 0 &= g(0) = g(1) = \cdots = g(m-1) \\
1 &= g(m) = g(m+1) = \cdots = g(2m-1) \\
\vdots \\
\end{aligned}
$$

将上述代码示例修改为 GQA 很容易:

与 MHA 相比,投影矩阵 k_projv_proj 不同。具体来说,这些投影矩阵更小,因此矩阵乘法计算速度更快。

由于投影矩阵 k_projq_proj 的形状不同,因此 qk 之间的乘法是不可能的。因此,您需要使用 repeat_interleave()k 扩展到与 q 相同的形状。这仅在 num_heads 可被 num_groups 整除时才有效。出于同样的原因,您还需要使用 repeat_interleave() 扩展 v

或者,您可以使用 PyTorch 内置的 `scaled_dot_product_attention()` 函数来简化上述实现:

请注意,调用 `scaled_dot_product_attention()` 并设置 `enable_gqa=True` 就可以替换对 `repeat_interleave()`、`matmul()` 和 `softmax()` 的调用。

研究发现,GQA 可以在对模型质量影响最小的情况下,降低内存使用和计算时间。

如果在 GQA 中将分组数设置为 1,它就变成了**多查询注意力**(MQA)。但如果将分组数设置为与查询头数相同,它就会退化为多头注意力。

进一步阅读

以下是一些与该主题相关的论文:

总结

在这篇文章中,您了解了语言模型中的注意力机制。特别是,您了解了:

  • 为什么注意力对于捕捉序列中的关系至关重要
  • 多头注意力如何实现不同类型的关系
  • 分组查询注意力如何平衡效率和性能

暂无评论。

发表评论

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。