type
status
date
slug
summary
tags
category
icon
password
本文主要翻译及调整来自 Understanding Rotary Positional Encoding | by Ngieng Kianyew | Medium,并做了一些错误修订及扩充。
旋转位置编码是最先进的 NLP 位置嵌入技术。大多数流行的大型语言模型(如 Llama、Llama2、Qwen、PaLM 和 CodeGen)已经在使用它,而不是原始"Attention Is All You Need"论文中使用的绝对位置编码或 Self-Attention with Relative Position Representations 中提出的相对位置编码。
在旋转位置编码之前,对于绝对位置编码还是相对位置编码最好没有明确的答案,但一如既往,在比较两种技术时,答案是将它们全部结合起来(旋转位置编码)。
本文将分为 4 个部分:
- 绝对位置编码的主要缺点
- 相对位置编码的主要缺点
- 什么是旋转位置编码及其工作原理
- 旋转位置编码如何克服绝对和相对位置编码的缺点
引子
我们知道我们需要位置编码,因为自注意力不考虑单词在序列中的位置。绝对位置编码是位置编码的一种形式,我们使用正弦函数创建位置嵌入 ,其中每个元素 ( 对应于位置 。绝对位置编码按元素添加到input token embedding
中相对位置编码是位置编码的一种形式,我们使用一个位置和其他位置之间的成对距离。相对位置编码在 层之前按元素添加到形状 的注意力矩阵中。
绝对位置编码的主要缺点
不包含相对位置信息
虽然绝对位置编码捕获单词的位置信息,但它不会捕获整个句子(或序列)的位置信息。
例子:
- 创建长度为 3 的绝对位置编码的常见方法是进行随机初始化。假设我们得到以下内容:
[0.1, 0.01, 0.5]
(这种绝对位置编码将确保中的相同单词不同的位置会有不同的注意力输出)。
- 但是,请注意如果我们要分析绝对位置编码
[0.1, 0.01, 0.5]
会发生什么: a. 位置之间没有关系。 较大位置索引处的位置编码可以大于或小于较小位置索引 position=1
(0.1) 处的位置编码可以大于position=2
(0.01),即[0.1 > 0.01]
position=1
处的位置编码(0.1)也可以小于position=3
处的位置编码(0.5),即[0.1 < 0.5]
b. 相对距离不一致。 位置编码的差异并不能告诉我们单词之间的距离有多远。position=1
到position=2
的距离是 。position=1
到position=3
的距离为 。(理想情况下position=1
到position=3
的距离应大于position=1
到position=2
的距离)
- 这意味着绝对位置编码不会捕获整个句子(或序列)的位置信息。
相对位置编码的主要缺点
- 计算效率低下 需要创建一个额外的步骤来进行自注意力。请记住,我们必须创建成对的位置编码矩阵,然后执行相当多的张量操作以获得每个时间步的相对位置编码。
- 不适合推理
- 在推理过程中,研究人员喜欢使用一种称为 KV cache 的方法,这有助于提高推理速度。
- 使用 KV cache 的一项要求是已经生成的单词的位置编码,在生成新单词时不改变(绝对位置编码提供)。
- 因此,相对位置编码不适合推理,因为每个标记的嵌入会随着每个新时间步的变化而变化。
例子:
- 当序列长度为 2 时,一个单词的相对位置将为
[-1, 0, 1]
。
- 当输入序列的长度为 3 时,这个单词的相对位置将为
[-2,-1,0,1,2]
。
- 由于我们使用这些相对位置来获取每个单词的相对位置编码,因此当相对位置集发生变化时,每个单词的位置编码也会发生变化。
- 这意味着,给定句子
['this'、'is'、'awesome']
,当您生成用于推理的token时,您对单词"this"的位置编码会在每个时间步发生变化。
这就是相对位置编码不常用的原因。
什么是旋转位置编码
旋转位置编码是一种位置编码,它使用旋转矩阵对绝对位置信息进行编码,并自然地将显式相对位置依赖性纳入自注意力公式中
什么是旋转矩阵
顾名思义,旋转矩阵是将一个向量旋转某个角度到另一个向量的矩阵。
旋转矩阵源自我们在高中学到的正弦和余弦的三角性质,使用二维矩阵应该足以获得旋转矩阵的直觉,如下所示! (如果您不确定,请参阅下面的链接了解二维旋转矩阵的推导)
我们看到旋转矩阵保留了原始向量的大小(或长度),如上图中的"r"所示,唯一改变的是与 x 轴的角度。
二维旋转矩阵如何参与旋转位置编码?
旋转位置编码中使用的旋转矩阵是二维旋转矩阵的多个块,如上所示
利用二维旋转矩阵来旋转向量!
因为我们的向量通常都是二维以上的。理解二维旋转矩阵比较容易,所以作者想到了一个聪明的方法,使用二维旋转矩阵来旋转整个 维向量。这个想法是对每对维度(每组大小为 2 的向量)按某个角度使用不同的二维旋转矩阵。由于我们使用的是"维度对",因此我们要求"D"能被 2 整除,这应该不是问题,因为在深度学习中,D 通常被选择为 2 的某个指数(即 ),可被 2 整除
解释上图中的数学符号:
是单词的位置。 请注意, 对于所有二维向量对都是相同的
是选择用于旋转向量的某个标量值。
请注意, 的下标是对应于一对二维向量的索引。
每对 向量旋转的角度为 ,其中 是对应的 向量对的索引。这意味着每对 向量将旋转不同的角度。
例子:
给定句子
[this, is, Awesome]
,让我们看看它如何旋转单词"this"和"is"的位置编码。
假设单词的嵌入是 6 维向量:我们知道:
单词
this
位于 position 0,因此 。(请注意,对于单词 this,这在所有旋转矩阵中都是相同的)
单词 is 位于positon 1,因此 。(请注意,这在单词 is 的所有旋转矩阵中都是相同的)由于我们有 6 维向量,这意味着每个单词有 对 向量(还有 3 个旋转矩阵)
对于每个单词,由于我们有 3 对 向量,因此我们将拥有 3 个 ,它们在单词 'this' 和 'is' 之间共享
假设 为 [0.1, 0.4, 0.6]
由于我们有 3 对 向量(在我们的示例中嵌入是 6 维向量),因此我们将有:
(1) 每个单词有 3 个旋转矩阵。
(2) 每个单词 3个角度(由 计算)用来旋转。
'this' 的 3 个角度: .
'is' 的 3 个角度: .
如果我们将这些值代入上面的旋转矩阵,并乘以单词"this"的 6 维向量
我们将
[x1,x2]
旋转角度
我们将 [x3,x4]
旋转角度
我们将 [x5,x6]
旋转角度 如果我们将这些值代入上面的旋转矩阵,并乘以单词"is"的 6 维向量
我们将
[x7,x8]
旋转角度
我们将 [x9,x10]
旋转角度
我们将 [x11,x12]
旋转角度 这意味着每个单词嵌入都由 3 对旋转矩阵旋转,每个旋转矩阵都有不同的旋转,这些旋转受每个单词的位置影响。
这意味着位置编码将考虑单词的位置,我们现在将能够为不同位置的同一个单词生成不同的注意力输出。
旋转位置编码如何克服绝对和相对位置编码的缺点
旋转位置编码将两者结合起来克服了这些缺点。
旋转位置编码中的绝对位置编码
请注意,在我们针对"this"和"is"进行旋转位置编码的示例中,旋转矩阵仅取决于 ,其中 在所有单词之间共享,并且 只是单词的位置(不用关心)至于其他词的其他位置是什么)
与绝对位置编码类似,我们对一个位置有一种位置编码(不依赖于其他单词的位置)
旋转位置编码中的相对位置编码
由于旋转角度取决于 ,因此句子开头的单词会有较小的旋转角度,而到达句子末尾的单词会有较大的旋转角度,使得距离是相对的。
请记住: 对于所有单词都是相同的, 单词的位置是唯一改变的。 越高意味着,意味着旋转角度越大。
与相对位置编码类似,不同位置的位置编码之间存在关系
现在,为了解释为什么它克服了缺点,让我们回顾一下这两种方法的缺点:
绝对位置编码:(1)不包括相对位置信息相对位置编码:(2)计算效率低,(3)不适合推理
因此我们知道旋转位置编码必须具有以下性质:
- 包括相对位置信息
- 提高计算效率
- 适合推理
性质-1:旋转位置编码包括相对位置信息
- 旋转位置编码通过使旋转矩阵中的角度取决于当前单词的位置来包含相对位置信息,并通过旋转矩阵的属性告诉我们两个单词相距多远。
- 例如,在上面的图 4 中,"pig"和"dog"之间的角度是相同的,即使它们 (1) 出现在长度不同的 2 个句子中,并且 (2) "pig"和"dog"出现在句子的不同位置!
如果计算两个句子中"pig"和"dog"之间的点积,它们将相同
单词"pig"和"dog"总是相隔 2 个单词("chased the"位于"pig"和"dog"之间)。
请记住,单词的旋转角度为 ,而 "pig" 和 "dog" 之间的角度相隔 2 个单词的"dog"为
另一种看待它的方法是绘制不同位置的旋转矩阵内的角度:
请记住,旋转矩阵内部的角度表示为:
y 轴表示旋转角度
x 轴表示隐藏维度
对于下面的图表,您应该查看整个图表的向上移动,以研究更改 (单词的位置)的影响。
从图中我们可以看到,当
t
从 1 增加到 7 时,旋转角度也随之增大(如整个图的移动所示)。这意味着如果某个单词的位置编码旋转角度较大,则该单词沿着句子更远。 → t=1
和 t=3
之间的间隙大于 t=1
和 t=2
之间的间隙从图中,我们注意到连续两对位置之间的差异大小是相同的。
这意味着一个时间步长的变化带来了相对相同的旋转角度变化。例如,
t=1
和 t=2
之间的间隙与 t=2
和 t=3
之间的间隙相同,依此类推。这意味着相对距离一致。例如,任何相距 3 个单词的两个单词将具有相同的旋转角度。
性质-的另一个视角
到目前为止我们只讨论了"旋转角度",但是编码一个单词到其他位置的相对距离的部分在哪里?
答案在于以下点积的定义:
"向量A和B的点积等于A的长度乘以B的长度乘以它们之间的角度的余弦"
数学表达式为:
其中:
- 表示向量A和B的点积
- 和 分别表示向量A和B的长度
- 表示向量A和B之间的夹角
- 表示夹角的余弦值
如果我们对位置 1 的旋转位置编码与其他位置进行点积,我们会得到 ,它仅取决于旋转角度(更明确地说,是与位置 1 的旋转角度之差)到另一个单词的位置)。
位置 = 1 的旋转角度小于位置 > 1 的旋转角度。这意味着另一个词距离位置 1 (1) 越远,旋转角度的差异就越大 (2)点积越小(因为余弦值从 下降到 ) (3) 另一个词距离位置 1 越远
我们得出结论,旋转位置编码包含相对位置信息
如何使用旋转位置编码:
- 因为我们需要点积包含相对位置编码,所以我们分别对 Query 和 Key 应用旋转位置编码,这样当我们将它们矩阵相乘时,注意力矩阵就包含相对位置编码信息。 (参考这个repo)
性质-:旋转位置编码计算效率高
- 使用旋转位置编码似乎计算效率不高,因为我们必须创建旋转矩阵,但研究人员发现了一个计算效率高的公式。
Show Image
性质-:旋转位置编码适合推理
由于我们已经知道旋转位置编码类似于绝对位置编码,编码仅取决于当前单词的位置,因此已经生成的单词的位置编码不会改变,从而在推理过程中再次可以进行 KV cache
实现
- 生成旋转角度
- 应用旋转
在实际使用时,通常只需要在计算attention scores之前对query和key应用这个旋转即可。
结论
绝对位置编码的缺点是它没有考虑句子中的相对位置信息
相对位置编码的缺点是计算量大且不适合推理
旋转位置编码通过利用每种方法的优点将两者结合起来,并且还利用计算效率高的方法来有效地计算它
旋转位置编码的基本思想是利用 二维旋转矩阵来旋转 维矩阵,其中每个二维旋转矩阵的旋转角度由 决定(其中 为这个词和 是一个预先计算的术语,在所有词中共享),并产生一条指数曲线。当 增加时,指数曲线向上移动到更大的值,表示更高的位置距离。
参考文献
Peter Shaw, Jakob Uszkoreit, and Ashish Vaswani. 2018. Self-Attention with Relative Position Representations. arXiv:1803.02155 [cs].
Jianlin Su, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, and Yunfeng Liu. 2023. RoFormer: Enhanced Transformer with Rotary Position Embedding. arXiv: 2104.09864 [cs].
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. 2023. Attention Is All You Need. arXiv:1706.03762 [cs].