位置编码中的插值与使用 YaRN 扩展更长的上下文窗口

Transformer 模型在固定序列长度下进行训练,但在推理时,它们可能需要处理不同长度的序列。这带来了挑战,因为位置编码是基于序列长度计算的。模型可能会难以处理在训练期间未遇到过的位置编码。

处理可变序列长度的能力对模型至关重要。本文探讨了不同的位置编码方法如何解决这一挑战。

让我们开始吧。

位置编码中的插值与使用 YaRN 扩展更长的上下文窗口
照片作者:enkuu smile_。部分权利保留。

概述

这篇博文分为三部分;它们是:

  • 正弦编码和 RoPE 中的插值与外插
  • 学习型编码中的插值
  • 用于更大上下文窗口的 YaRN

正弦编码和 RoPE 中的插值与外插

正弦编码因使用连续函数而在外插方面表现出色

$$
\begin{aligned}
PE(p, 2i) &= \sin\left(\frac{p}{10000^{2i/d}}\right) \\
PE(p, 2i+1) &= \cos\left(\frac{p}{10000^{2i/d}}\right)
\end{aligned}
$$

您只需将 $p$ 替换为更大的值,即可获得更长序列的位置编码。这就是外插。

或者,您也可以使用插值。与其使用从 0 到 $L-1$ 的整数 $p$ 作为序列长度 $L$,不如让 $p$ 成为同一范围内的浮点数,以表示长度 $L’>L$。即:

$$p = \frac{L}{L’}p’$$

其中 $p’$ 是从 0 到 $L’-1$ 的整数位置(更长序列中的实际位置)。

这些技术也适用于 RoPE。

生成正弦位置编码或 RoPE 的函数无需修改即可处理任何长度的序列。但是,您可能需要对模型进行微调,以确保它能够处理未在训练期间见过的新编码。例如,Llama 模型使用 RoPE,并以 16K 的最大序列长度进行了训练。Code Llama 是一个从 Llama 微调而来的、专注于编程的模型,它仅通过 1000 次微调步骤就将序列长度扩展到了 100K 令牌。

学习型编码中的插值

学习型位置编码从查找表中检索位置编码向量。这意味着序列长度由表大小固定,无法进行外插。但是,插值仍然可以处理比训练长度更长的序列。对于长度为 $L’>L$(其中 $L$ 是原始序列长度)的序列,位置 $p’=0, \dots, L’-1$ 的编码向量是:

$$P_{p’} = \frac{p-n}{m-n}P_m + \frac{m-p}{m-n}P_n$$

其中 $p = \frac{L}{L’}p’$,并且 $m,n$ 是整数,使得 $m=n+1$ 且 $n\le p\le m$。在 PyTorch 中:

这是一个基本插值实现。但是,不能保证模型在重新训练之前能够处理更长的序列而不降低性能。

用于更大上下文窗口的 YaRN

RoPE 目前是大型语言模型中最广泛使用的位置编码。最近的研究一直致力于改进 RoPE 的外插能力。

YaRN 是一种将 RoPE 扩展到处理更长序列的方法,它比上述插值方法更有效。回想一下 RoPE 正弦波的计算公式:

$$
\begin{aligned}
\theta_i &= \frac{1}{10000^{2i/d}} \\
\hat{x}_m^{(i)} &= x_m^{(i)} \cos(m\theta_i) + x_m^{(d/2+i)} \sin(m\theta_i) \\
\hat{x}_m^{(d/2+i)} &= x_m^{(d/2+i)} \cos(m\theta_i) – x_m^{(i)} \sin(m\theta_i)
\end{aligned}
$$

其中 $m$ 是序列中的位置,向量 $x_m$ 的维度是 $d$,而 $x_m^{(i)}$ 是向量 $x_m$ 的第 $i$ 个元素。YaRN 修改了公式为:

$$
\begin{aligned}
s &= \frac{L’}{L} \\
\theta_i &= \frac{1}{10000^{2i/d}} \\
r(i) &= \frac{L}{2\pi(10000^{2i/d})} \\
\gamma(r) &= \begin{cases}
0, & \text{if } r < \alpha \\ \dfrac{r – \alpha}{\beta – \alpha}, & \text{if } \alpha \le r \le \beta \\ 1, & \text{if } r > \beta
\end{cases} \\
\theta_i’ &= \Big[1-\gamma\big(r(i)\big)\Big]\frac{\theta_i}{s} + \gamma\big(r(i)\big)\theta_i \\
\sqrt{\frac{1}{t}} &= 0.1\ln(s) + 1 \\
\hat{x}_m^{(i)} &= \sqrt{\frac{1}{t}} \Big[x_m^{(i)} \cos(m\theta_i’) + x_m^{(d/2+i)} \sin(m\theta_i’)\Big] \\
\hat{x}_m^{(d/2+i)} &= \sqrt{\frac{1}{t}} \Big[x_m^{(d/2+i)} \cos(m\theta_i’) – x_m^{(i)} \sin(m\theta_i’)\Big]
\end{aligned}
$$

代码实现:

YaRN 的关键创新是在将序列长度从 $L$ 扩展到 $L’$ 时,不对 RoPE 正弦频率进行均匀缩放。这种方法称为“NTK-by-parts”插值。

考虑 RoPE 公式中的 $\cos(m\theta_i)$ 项,其中对于新的序列长度 $L’$, $m$ 的范围是从 0 到 $L’-1$。在 inv_freq_interpolation 中,通过将 $\theta_i$ 乘以因子 $1/s = L/L’$ 来产生插值效果。使用原始的 $\cos(m\theta_i)$ 并使用更大的 $m$ 则构成外插。

在 NTK-by-part 中,您使用 $\cos(m\theta_i’)$ 而不是 $\cos(m\theta_i)$,其中 $\theta_i’$ 混合了插值和外插。插值和外插之间的权重遵循上述公式,在代码中实现为 inv_freq

YaRN 通过添加缩放因子 $\sqrt{1/t}$ 来改进 NTK-by-part。此增强功能通过在更长的上下文长度下降低的困惑度(下一个令牌预测中的更高准确度)来提高模型性能。

进一步阅读

以下是一些与该主题相关的论文:

总结

在本帖中,您了解了经过较短上下文长度训练的模型如何处理更长的输入序列。具体而言:

  • 正弦编码和 RoPE 可以轻松进行外插
  • 学习型编码仅支持插值
  • YaRN 提供了一种先进的方法,用于将 RoPE 扩展到更长的序列长度

缩放位置编码的目的是使模型能够在不重新训练的情况下处理更长的输入序列。这并非详尽无遗的列表,因为更先进的方法将继续建立在这些基本思想的基础上。

暂无评论。

发表回复

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。