Transformer 模型中的注意力机制需要处理各种约束,以防止模型关注某些位置。本文探讨了注意力掩码如何实现这些约束及其在现代语言模型中的实现。
让我们开始吧。

Transformer 模型中注意力掩码(Attention Masking)的简明介绍
图片来源:Caleb Jack。保留部分权利。
概述
本文分为四个部分;它们是:
- 为什么需要注意力掩码
- 注意力掩码的实现
- 掩码创建
- 使用 PyTorch 内置的注意力机制
为什么需要注意力掩码
在上一篇文章中,你了解了注意力机制如何让模型关注序列的相关部分。然而,在某些情况下,你希望阻止模型关注某些位置。
- 因果掩码:在语言建模和文本生成中,模型应该只关注之前的词元,而不是未来的词元。因果掩码可以防止在训练过程中未来信息泄露。
- 填充掩码:当处理不同长度的序列批次时,较短的序列会用特殊词元进行填充。模型应该忽略这些填充词元。这是推理过程中最常见的掩码用法。
- 自定义掩码:在某些应用中,我们可能希望根据特定领域的规则阻止对特定词元或位置的关注。
考虑语言模型正在学习的句子“The cat sat on the mat”。当训练模型预测单词“sat”时,它应该只考虑“The cat”,而不是“on the mat”,以避免通过查看未来信息作弊。
对于因果掩码,如果你以“The cat sat on the mat”作为输入来训练模型,你将使用以下掩码:
$$
\begin{bmatrix}
1 & 0 & 0 & 0 & 0 & 0 \\
1 & 1 & 0 & 0 & 0 & 0 \\
1 & 1 & 1 & 0 & 0 & 0 \\
1 & 1 & 1 & 1 & 0 & 0 \\
1 & 1 & 1 & 1 & 1 & 0 \\
1 & 1 & 1 & 1 & 1 & 1
\end{bmatrix}
$$
此掩码是一个全 1 的下三角矩阵。元素 (i,j) 为 1 表示查询词元 i 可以关注键词元 j。下三角结构确保键序列的长度永远不会超过查询序列的长度,即使在训练期间将完整序列输入到模型中也是如此。
某些模型(如 BERT)是“双向的”,它们预测被掩码的词元而不是下一个词元。这些模型使用在随机位置包含 0 的掩码进行训练。
在推理过程中,你可能会向模型传递一批序列
1 2 |
[["The", "cat", "sat", "on", "the", "mat"], ["Once", "upon", "a", "time"]] |
此批次包含两个长度不等的序列。经过预处理和填充后
1 2 3 |
[["The", "cat", "sat", "on", "the", "mat"], ["In", "the", "beginning", "<PAD>", "<PAD>", "<PAD>"], ["Once", "upon", "a", "time", "<PAD>", "<PAD>"]] |
为确保模型忽略填充词元,你创建了一个如下所示的掩码
$$
\begin{bmatrix}
1 & 1 & 1 & 1 & 1 & 1 \\
1 & 1 & 1 & 0 & 0 & 0 \\
1 & 1 & 1 & 1 & 0 & 0
\end{bmatrix}
$$
在这里,对应于填充词元的位置设置为 0,而所有其他位置设置为 1。
注意力掩码的实现
在上一个帖子中注意力模块的基础上,你可以修改它以支持掩码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import torch import torch.nn as nn import torch.nn.functional as F import math class MultiHeadAttention(nn.Module): def __init__(self, d_model, num_heads,dropout_prob=0): super().__init__() self.d_model = d_model self.num_heads = num_heads self.head_dim = d_model // num_heads self.dropout_prob = dropout_prob self.q_proj = nn.Linear(d_model, d_model) self.k_proj = nn.Linear(d_model, d_model) self.v_proj = nn.Linear(d_model, d_model) self.out_proj = nn.Linear(d_model, d_model) def forward(self, x, mask=None): batch_size = x.size(0) seq_length = x.size(1) # Project queries, keys, and values q = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) # 计算注意力分数 scores = torch.matmul(q, k.transpose(2,3)) / math.sqrt(self.head_dim) # 对注意力分数应用掩码 if mask is not None: scores = scores.masked_fill(mask == 0, float("-inf")) # 应用 softmax 计算注意力权重 attn_weights = F.softmax(scores, dim=-1) # 应用 dropout if self.dropout_prob: attn_weights = F.dropout(attn_weights, p=self.dropout_prob) # 将注意力权重应用于值 context = torch.matmul(attn_weights, v).transpose(1,2).contiguous() context = context.view(batch_size, seq_length, self.d_model) return self.out_proj(context) |
这是带有掩码和 dropout 的多头注意力的标准实现。掩码在 softmax 之前应用于注意力分数。用数学术语来说,掩码是一个矩阵 $M$,使得
$$
\text{Attention}(Q, K, V, M) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d}} + M\right)V
$$
掩码必须在 softmax 之前添加,因为 softmax 在整行上操作。你不想让 softmax 考虑被掩码的元素。由于 softmax 计算
$$
\text{softmax}(x_i) = \frac{\exp(x_i)}{\sum_{j=1}^n \exp(x_j)}
$$
为了使被掩码的元素对 softmax 的贡献为 0,你需要在这些位置添加 $-\infty$。这就是 `masked_fill()` 函数所实现的功能。
鉴于此实现,如果掩码是 $-\infty$ 和 0 值的矩阵,你也可以直接使用它
1 2 3 |
... if mask is not None: scores = scores + mask |
下一节将向你展示如何为不同的用例创建掩码。
掩码创建
由于掩码是必不可少且广泛使用的,因此创建专门用于掩码生成的函数非常有价值
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import torch def create_causal_mask(seq_len): """ 为自回归注意力创建一个因果掩码。 参数 seq_len: 序列的长度 返回 形状为 (seq_len, seq_len) 的因果掩码 """ mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1) return mask def create_padding_mask(batch, padding_token_id): """ 为一批序列创建填充掩码。 参数 batch: 序列批次,形状为 (batch_size, seq_len) padding_token_id: 填充词元的 ID 返回 形状为 (batch_size, seq_len, seq_len) 的填充掩码 """ batch_size, seq_len = batch.shape padded = torch.zeros_like(batch).float().masked_fill(batch == padding_token_id, float('-inf')) mask = torch.zeros(batch_size, seq_len, seq_len) + padded[:,:,None] + padded[:,None,:] return mask[:, None, :, :] print(create_causal_mask(5)) batch = torch.tensor([ [1, 2, 3, 4, 5, 6], [1, 2, 3, 0, 0, 0], [1, 2, 3, 4, 0, 0] ]) print(create_padding_mask(batch, 0)) |
这是两种最常见的掩码类型。你可以将它们扩展到其他用例。在 `create_causal_mask()` 中,你创建了一个对角线上方为 $-\infty$ 值的上三角矩阵。值为 0 的位置允许注意力。
在 `create_padding_mask()` 中,你首先使用与 `batch` 具有相同形状的 `padded` 张量识别批次中的填充词元。输出掩码的形状为 `(batch_size, seq_len, seq_len)`,最初所有元素为 0,然后通过两次添加 `padded` 张量进行修改:一次用于行,一次用于列。
运行此代码会产生
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
tensor([[0., -inf, -inf, -inf, -inf], [0., 0., -inf, -inf, -inf], [0., 0., 0., -inf, -inf], [0., 0., 0., 0., -inf], [0., 0., 0., 0., 0.]]) tensor([[[[0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.]]], [[[0., 0., 0., -inf, -inf, -inf], [0., 0., 0., -inf, -inf, -inf], [0., 0., 0., -inf, -inf, -inf], [-inf, -inf, -inf, -inf, -inf, -inf], [-inf, -inf, -inf, -inf, -inf, -inf], [-inf, -inf, -inf, -inf, -inf, -inf]]], [[[0., 0., 0., 0., -inf, -inf], [0., 0., 0., 0., -inf, -inf], [0., 0., 0., 0., -inf, -inf], [0., 0., 0., 0., -inf, -inf], [-inf, -inf, -inf, -inf, -inf, -inf], [-inf, -inf, -inf, -inf, -inf, -inf]]]]) |
这些掩码可以直接用作上面 `MultiHeadAttention` 类的 `forward()` 方法中的 `mask` 参数。
使用 PyTorch 内置的带掩码注意力机制
上面 `forward()` 方法中的矩阵乘法和 softmax 操作可以用 PyTorch 内置的 SDPA 函数替换
1 2 |
... context = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout_prob) |
代码的所有其他部分保持不变,包括投影矩阵和掩码创建函数。
或者,你可以使用 PyTorch 内置的 `MultiheadAttention` 类。将其与掩码一起使用非常简单
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import torch dim = 16 num_heads = 4 attn_layer = torch.nn.MultiheadAttention(dim, num_heads, dropout=0.1, batch_first=True) # 输入张量:0 = 填充 batch = torch.tensor([ [1, 2, 3, 4, 5, 6], [1, 2, 3, 0, 0, 0], [1, 2, 3, 4, 0, 0] ]) batch_size, seq_len = batch.shape x = torch.randn(batch_size, seq_len, dim) padding_mask = (batch == 0) y = attn_layer(x, x, x, key_padding_mask=padding_mask, attn_mask=None) |
创建注意力层时,你只需指定维度大小和头数。该类在内部处理所有投影矩阵和 dropout。请注意,你应设置 `batch_first=True` 以使用形状为 `(batch_size, seq_len, dim)` 的输入张量。
上面的代码演示了使用 `MultiheadAttention` 进行自注意力,其中相同的张量 `x` 充当查询、键和值。如果你的输入张量包含填充词元,你可以使用 `key_padding_mask` 来指示应该掩码哪些注意力位置。
为了更精确地控制注意力掩码,你可以使用 `attn_mask` 参数
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import torch def create_mask(query, key, padding_token_id): """ 为一批序列创建填充掩码。 参数 query: 用于查询的序列批次,形状为 (batch_size, query_len) key: 用于键的序列批次,形状为 (batch_size, key_len) padding_token_id: 填充词元的 ID 返回 形状为 (batch_size, query_len, key_len) 的填充掩码 """ batch_size, query_len = query.shape _, key_len = key.shape q_padded = torch.zeros_like(query).float().masked_fill(query == padding_token_id, float('-inf')) k_padded = torch.zeros_like(key).float().masked_fill(key == padding_token_id, float('-inf')) mask = torch.zeros(batch_size, query_len, key_len) + q_padded[:,:,None] + k_padded[:,None,:] return mask dim = 16 num_heads = 4 attn_layer = torch.nn.MultiheadAttention(dim, num_heads, dropout=0.1, batch_first=True) # 输入张量:0 = 填充 batch = torch.tensor([ [1, 2, 3, 4, 5, 6], [1, 2, 3, 0, 0, 0], [1, 2, 3, 4, 0, 0] ]) batch_size, seq_len = batch.shape x = torch.randn(batch_size, seq_len, dim) attn_mask = create_mask(batch, batch, 0) attn_mask = attn_mask.repeat(1, num_heads, 1, 1).view(-1, seq_len, seq_len) y = attn_layer(x, x, x, key_padding_mask=None, attn_mask=attn_mask) |
使用 `attn_mask` 需要更多的设置,因为它期望一个形状为 `(batch_size * num_heads, query_len, key_len)` 的 3D 掩码。`create_mask()` 函数创建了一个形状为 `(batch_size, query_len, key_len)` 的 3D 掩码,表示查询-键矩阵中的填充词元位置。然后,你使用 `repeat()` 将掩码复制到每个注意力头。这是内置 `MultiHeadAttention` 类期望的格式。
进一步阅读
以下是一些您可能会觉得有用的资源:
- 注意力就是你所需要的一切
- PyTorch MultiheadAttention API 文档
- PyTorch Scaled Dot Product Attention API 文档
- 多头注意力 (MHA)
总结
在这篇文章中,你了解了 Transformer 模型中的注意力掩码。具体来说,你了解了
- 为什么注意力掩码对于防止信息泄露和处理变长序列是必要的
- 不同类型的掩码及其应用
- 如何在自定义和 PyTorch 内置的注意力机制中实现注意力掩码
暂无评论。