Transformer 模型在固定序列长度下进行训练,但在推理时,它们可能需要处理不同长度的序列。这带来了挑战,因为位置编码是基于序列长度计算的。模型可能会难以处理在训练期间未遇到过的位置编码。
处理可变序列长度的能力对模型至关重要。本文探讨了不同的位置编码方法如何解决这一挑战。
让我们开始吧。

位置编码中的插值与使用 YaRN 扩展更长的上下文窗口
照片作者:enkuu smile_。部分权利保留。
概述
这篇博文分为三部分;它们是:
- 正弦编码和 RoPE 中的插值与外插
- 学习型编码中的插值
- 用于更大上下文窗口的 YaRN
正弦编码和 RoPE 中的插值与外插
正弦编码因使用连续函数而在外插方面表现出色
$$
\begin{aligned}
PE(p, 2i) &= \sin\left(\frac{p}{10000^{2i/d}}\right) \\
PE(p, 2i+1) &= \cos\left(\frac{p}{10000^{2i/d}}\right)
\end{aligned}
$$
您只需将 $p$ 替换为更大的值,即可获得更长序列的位置编码。这就是外插。
或者,您也可以使用插值。与其使用从 0 到 $L-1$ 的整数 $p$ 作为序列长度 $L$,不如让 $p$ 成为同一范围内的浮点数,以表示长度 $L’>L$。即:
$$p = \frac{L}{L’}p’$$
其中 $p’$ 是从 0 到 $L’-1$ 的整数位置(更长序列中的实际位置)。
这些技术也适用于 RoPE。
生成正弦位置编码或 RoPE 的函数无需修改即可处理任何长度的序列。但是,您可能需要对模型进行微调,以确保它能够处理未在训练期间见过的新编码。例如,Llama 模型使用 RoPE,并以 16K 的最大序列长度进行了训练。Code Llama 是一个从 Llama 微调而来的、专注于编程的模型,它仅通过 1000 次微调步骤就将序列长度扩展到了 100K 令牌。
学习型编码中的插值
学习型位置编码从查找表中检索位置编码向量。这意味着序列长度由表大小固定,无法进行外插。但是,插值仍然可以处理比训练长度更长的序列。对于长度为 $L’>L$(其中 $L$ 是原始序列长度)的序列,位置 $p’=0, \dots, L’-1$ 的编码向量是:
$$P_{p’} = \frac{p-n}{m-n}P_m + \frac{m-p}{m-n}P_n$$
其中 $p = \frac{L}{L’}p’$,并且 $m,n$ 是整数,使得 $m=n+1$ 且 $n\le p\le m$。在 PyTorch 中:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
class ExtrapolatingLearnedEncoding(nn.Module): def __init__(self, max_trained_len, d): super().__init__() self.max_trained_len = max_trained_len self.position_embeddings = nn.Embedding(max_trained_len, d) def forward(self, x) seq_len = x.size(1) if seq_len <= self.max_trained_len: # 正常情况:使用学习到的嵌入 positions = torch.arange(seq_len, device=x.device) return x + self.position_embeddings(positions) else: # 外插情况:使用插值 positions = torch.arange(seq_len, device=x.device) # 在现有位置之间进行插值 scale = (self.max_trained_len - 1) / (seq_len - 1) scaled_positions = positions * scale # 获取向下取整和向上取整的位置 pos_floor = torch.floor(scaled_positions).long() pos_ceil = torch.ceil(scaled_positions).long() # 获取插值权重 weights = (scaled_positions - pos_floor.float()).unsqueeze(-1) # 进行插值 emb_floor = self.position_embeddings(pos_floor) emb_ceil = self.position_embeddings(pos_ceil) return x + (1 - weights) * emb_floor + weights * emb_ceil |
这是一个基本插值实现。但是,不能保证模型在重新训练之前能够处理更长的序列而不降低性能。
用于更大上下文窗口的 YaRN
RoPE 目前是大型语言模型中最广泛使用的位置编码。最近的研究一直致力于改进 RoPE 的外插能力。
YaRN 是一种将 RoPE 扩展到处理更长序列的方法,它比上述插值方法更有效。回想一下 RoPE 正弦波的计算公式:
$$
\begin{aligned}
\theta_i &= \frac{1}{10000^{2i/d}} \\
\hat{x}_m^{(i)} &= x_m^{(i)} \cos(m\theta_i) + x_m^{(d/2+i)} \sin(m\theta_i) \\
\hat{x}_m^{(d/2+i)} &= x_m^{(d/2+i)} \cos(m\theta_i) – x_m^{(i)} \sin(m\theta_i)
\end{aligned}
$$
其中 $m$ 是序列中的位置,向量 $x_m$ 的维度是 $d$,而 $x_m^{(i)}$ 是向量 $x_m$ 的第 $i$ 个元素。YaRN 修改了公式为:
$$
\begin{aligned}
s &= \frac{L’}{L} \\
\theta_i &= \frac{1}{10000^{2i/d}} \\
r(i) &= \frac{L}{2\pi(10000^{2i/d})} \\
\gamma(r) &= \begin{cases}
0, & \text{if } r < \alpha \\ \dfrac{r – \alpha}{\beta – \alpha}, & \text{if } \alpha \le r \le \beta \\ 1, & \text{if } r > \beta
\end{cases} \\
\theta_i’ &= \Big[1-\gamma\big(r(i)\big)\Big]\frac{\theta_i}{s} + \gamma\big(r(i)\big)\theta_i \\
\sqrt{\frac{1}{t}} &= 0.1\ln(s) + 1 \\
\hat{x}_m^{(i)} &= \sqrt{\frac{1}{t}} \Big[x_m^{(i)} \cos(m\theta_i’) + x_m^{(d/2+i)} \sin(m\theta_i’)\Big] \\
\hat{x}_m^{(d/2+i)} &= \sqrt{\frac{1}{t}} \Big[x_m^{(d/2+i)} \cos(m\theta_i’) – x_m^{(i)} \sin(m\theta_i’)\Big]
\end{aligned}
$$
代码实现:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import torch import numpy as np def rotate_half(x): x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(x, cos, sin): return (x * cos) + (rotate_half(x) * sin) class YaRN(nn.Module): def __init__(self, dim, orig_seq_len=512, scale=4, alpha=1, beta=32): super().__init__() N = 10000 pos_freq = N ** (torch.arange(0, dim, 2).float() / dim) inv_freq_extrapolation = 1. / pos_freq inv_freq_interpolation = 1. / (scale * pos_freq) low = dim * np.log(orig_seq_len / (2*np.pi*beta)) / (2*np.log(N)) high = dim * np.log(orig_seq_len / (2*np.pi*alpha)) / (2*np.log(N)) low = max(np.floor(low), 0) high = min(np.ceil(high), dim-1) linear_func = (torch.arange(dim // 2).float() - low) / (high - low) ramp_func = torch.clamp(linear_func, 0, 1) inv_freq_factor = 1 - ramp_func inv_freq = inv_freq_interpolation * (1-inv_freq_factor) + inv_freq_extrapolation * inv_freq_factor # 原始 RoPE 乘以一个缩放因子 scaling_factor = 0.1 * np.log(scale) + 1.0 position = torch.arange(orig_seq_len * scale).float() sinusoid_inp = torch.outer(position, inv_freq) self.register_buffer("cos", sinusoid_inp.cos() * scaling_factor) self.register_buffer("sin", sinusoid_inp.sin() * scaling_factor) def forward(self, x, seq_len=None): if seq_len is None: seq_len = x.size(1) cos = self.cos[:seq_len].view(1, seq_len, 1, -1) sin = self.sin[:seq_len].view(1, seq_len, 1, -1) return apply_rotary_pos_emb(x, cos, sin) |
YaRN 的关键创新是在将序列长度从 $L$ 扩展到 $L’$ 时,不对 RoPE 正弦频率进行均匀缩放。这种方法称为“NTK-by-parts”插值。
考虑 RoPE 公式中的 $\cos(m\theta_i)$ 项,其中对于新的序列长度 $L’$, $m$ 的范围是从 0 到 $L’-1$。在 inv_freq_interpolation
中,通过将 $\theta_i$ 乘以因子 $1/s = L/L’$ 来产生插值效果。使用原始的 $\cos(m\theta_i)$ 并使用更大的 $m$ 则构成外插。
在 NTK-by-part 中,您使用 $\cos(m\theta_i’)$ 而不是 $\cos(m\theta_i)$,其中 $\theta_i’$ 混合了插值和外插。插值和外插之间的权重遵循上述公式,在代码中实现为 inv_freq
。
YaRN 通过添加缩放因子 $\sqrt{1/t}$ 来改进 NTK-by-part。此增强功能通过在更长的上下文长度下降低的困惑度(下一个令牌预测中的更高准确度)来提高模型性能。
进一步阅读
以下是一些与该主题相关的论文:
- Transformers 的长度外插:从位置编码的视角进行调查
- Code Llama:用于代码的开放基础模型
- Mesa-Extrapolation:一种用于增强 LLM 外插能力的编织位置编码方法
- YaRN:大型语言模型的有效上下文窗口扩展
总结
在本帖中,您了解了经过较短上下文长度训练的模型如何处理更长的输入序列。具体而言:
- 正弦编码和 RoPE 可以轻松进行外插
- 学习型编码仅支持插值
- YaRN 提供了一种先进的方法,用于将 RoPE 扩展到更长的序列长度
缩放位置编码的目的是使模型能够在不重新训练的情况下处理更长的输入序列。这并非详尽无遗的列表,因为更先进的方法将继续建立在这些基本思想的基础上。
暂无评论。