RoPE 相对于正余弦位置编码和可学习位置编码,更能够表达相对位置信息,便于模型捕捉序列中元素之间的关系,还便于模型泛化到更长的序列,支持超长文本推理。
RoPE 的思路:对 Q,K 矩阵进行旋转,使计算得到的注意力权重天然带有两个 token 之间的相对距离。
对输入序列 {x0,x1,…},假设我们取两个 token:第 m 个和第 n 个
xm, xn∈Rdmodel方便推导,先假设 dmodel=2。令它们的 Query、Key 为:
qm=Wqxm,kn=Wkxn注意力的核心是内积:
⟨qm, kn⟩=qm⊤knRoPE 的想法是用“旋转”后的 Q/K 来计算内积:
⟨qmrope, knrope⟩其中
qmrope=qmeimθ,knrope=kneinθ
这里补充一下向量旋转的知识点:
复数的形式是:
z=a+bi其中 a 是实部,b 是虚部,i 是虚数单位。如果你把复平面画出来,会发现它和普通的 2D 坐标系完全一样:
- 横轴 = 实轴(real axis)
- 纵轴 = 虚轴(imag axis)
- 一个点 (x,y) 就是复数 x+yi
所以二维向量 (x,y) 可以等价地写成复数 z=x+yi,对于上文的二维向量 xm 有:
qm=[qm1qm2]=qm1+qm2i因为复数的乘法,天然包含了“旋转 + 缩放”的操作,对一个复数 z=x+yi,乘上一个单位模的复数:
eiθ=cosθ+isinθ就会让它绕原点旋转角度 θ,因此 qmrope=qmeimθ 表示把 qm 旋转 mθ。
对任意二维向量应用二维旋转矩阵也可以逆时针旋转:
R(θ)=[cosθsinθ−sinθcosθ]旋转矩阵的效果与 eimθ 相同, 证明如下:
qmrope=qmeimθ=[qm1qm2]eimθ=(qmi+qm2i)(cos(mθ)+isin(mθ))=(qm1cos(mθ)−qm2sin(mθ))+i(qm2cos(mθ)+qm1sin(mθ))=[qm1cos(mθ)−qm2sin(mθ)qm2cos(mθ)+qm1sin(mθ)]=[cos(mθ)sin(mθ)−sin(mθ)cos(mθ)][qm1qm2]=R(mθ)[qm1qm2]
基于前面的向量旋转的知识,我们进而得到:
<qmrope,knrope>=<R(mθ)qm,R(nθ)kn>=qmTRT(mθ)R(nθ)kn=qmTR(−mθ)R(nθ)kn=qmTR((n−m)θ)kn(证明1.1)(证明1.2)至此就可以看出,应用 RoPE 对向量进行旋转后,注意力权重就与两个 token 之间距离相关了。两个 token 距离越远,n-m 越大,旋转角度越大,注意力权重越小。
证明-1.1:
RT(θ)=[cosθ−sinθsinθcosθ]=[cos(−θ)sin(−θ)−sin(−θ)cos(−θ)]=R(−θ)证明-1.2:
R(−mθ)R(nθ) ⟷ e−imθeinθ=ei(n−m)θ=R((n−m)θ)
扩展到高维就有:
可以看到矩阵计算时候有非常多的 0,增大了计算量,简便方法就是:
并且有:
θi=10000−2i/d
def linear_RoPE(qk: torch.Tensor):
# x = [bs, len, dmodel]
_, seq_len, d_model = qk.size()
assert d_model % 2 == 0
position = torch.arange(seq_len, dtype=torch.float) # [max_len]
freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
sinusoid = torch.outer(position, freq)
cos, sin = torch.cos(sinusoid), torch.sin(sinusoid)
even, odd = qk[..., 0::2], qk[..., 1::2]
rotated_even = even * cos - odd * sin
rotated_odd = odd * cos + even * sin
return torch.stack([rotated_even, rotated_odd], dim=-1).reshape_as(qk)
代码的朴素实现就是一比一参考上图。
- 首先计算频率,也就是上图的 mθ,需要注意图片里面是一维向量,但实际上应该是三维的 [batch_size, seq_len, d_model//2]。我们可以构造出 0-seq_len-1 的向量和 θ0-θd/2−1 向量,然后求外积(ps:
torch.outer 等同于 unsqueeze(1) 之后逐点相乘)就能得到 [seq_len, d_model//2]。
- 然后将输入矩阵的 d_model 为按照奇偶分开。
- 最后偶数列就是 “偶数列*cos-奇数列*sin”,奇数列就是 “奇数列*cos+偶数列*sin”
- 通过
torch.stack 叠加在第四维,然后再 reshape 交错拼接在第三维。
def llama_RoPE(qk: torch.Tensor):
_, _, seq_len, dim = qk.shape
assert dim % 2 == 0, "dim must be even"
qk_complex = qk.view(*qk.shape[:-1], dim//2, 2) # [bsize, nheads, seq_len, dim//2, 2]
qk_complex = torch.view_as_complex(qk_complex) # [bsize, nheads, seq_len, dim//2]
position = torch.arange(seq_len, dtype=torch.float) # [max_len]
freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
sinusoid = torch.outer(position, freq)
rot = torch.exp(1j * sinusoid) # [seq_len, dim//2]
# rot = torch.polar(torch.ones_like(sinusoid), sinusoid)
rotated_qk_complex = qk_complex * rot
rotated_qk = torch.view_as_real(rotated_qk_complex) # [bsize, nheads, seq_len, dim//2, 2]
rotated_qk = rotated_qk.view_as(qk)
return rotated_qk
LLaMA 的实现方式更接近 RoPE 最朴素的想法:对 Q/K 进行旋转,它等价于对 Q/K 的每对维度进行一个二维旋转:
\begin{pmatrix}
\cos\theta & -\sin\theta\
\sin\theta & \cos\theta
\end{pmatrix}
\begin{pmatrix}
x_{2i}\
x_{2i+1}
\end{pmatrix}
$$
LLaMA 的想法是 把二维向量看成复数,二维旋转矩阵实际上等价于乘上复数:ejθ=cosθ+jsinθ 。qk_complex 就是将 qk 最后一个维度两两拆开组成复数,然后和单位模复数相乘将其旋转,最后再还原为二维向量。