-
?? RoPE为苏剑林大佬之作,最早应用于他自研的RoFormer (Rotary Transformer),属于相对位置编码。效果优于绝对位置编码和经典式相对位置编码。出自论文:《RoFormer: Enhanced Transformer with Rotary Position Embedding》 -
?? 据我了解,最近发布的大语言模型:Meta的LLaMA、清华的ChatGLM都采用了RoPE。这也足以证明了RoPE的优势。 -
?? 本文讲解下个人对RoPE原理的理解以及自己用torch复现了一下,更详细地请参阅苏神的原文(文末已附上链接)。 -
?? 如对RoPE公式推导有任何疑问,可评论区或私信反馈,我将做出详细解答。
文章目录
- 1、RoPE 动机
-
- 1.1、绝对位置编码
- 1.2、相对位置编码
- 1.3、RoPE
- 2、RoPE 原理
-
- 2.1、将待解问题公式化(提出假设)
- 2.2、推导求解
- 2.3、RoPE的编码形式
- 3、RoPE 代码实现(torch版)
- Reference
1、RoPE 动机
1.1、绝对位置编码
-
最原始的正余弦位置编码(即sinusoidal位置编码)是一种绝对位置编码,但从其原理中的正余弦的和差化积公式来看,引入的其实也是相对位置编码。
-
绝对位置编码的讲解可看我的博客:随记·手撕coding | absolute positional embedding
-
优势: 实现简单,可预先计算好,不用参与训练,速度快。
-
劣势: 没有外推性,即如果预训练最大长度为512的话,那么最多就只能处理长度为512的句子,再长就处理不了了。当然,也可以将超过512的位置向量随机初始化,然后继续微调。
1.2、相对位置编码
- 经典相对位置编码RPR式的讲解可看我的博客:相对位置编码之RPR式:《Self-Attention with Relative Position Representations》论文笔记 【在k, v中注入相对位置信息】
- 优势: 直接地体现了相对位置信号,效果更好。具有外推性,处理长文本能力更强。
1.3、RoPE
- RoPE通过绝对位置编码的方式实现相对位置编码,综合了绝对位置编码和相对位置编码的优点。
- 主要就是对attention中的q, k向量注入了绝对位置信息,然后用更新的q,k向量做attention中的内积就会引入相对位置信息了。
2、RoPE 原理
- 有关复数相关基础知识可看这:数学 | 复数的代数、向量、矩阵、极坐标、指数形式 | 复数相乘的物理意义【旋转+缩放】
2.1、将待解问题公式化(提出假设)
首先,假设新的qk向量(即假设已注入绝对位置信息)的内积会引入相对位置信息。并在最后假设合理的初始化条件:
2.2、推导求解
不是一般性,考虑其q,k向量为二维的情形,借助复数域推导出为q,k向量编码绝对位置信息的函数 f 。
别看公式多,理解起来并不难。下面我细说一下其中几个关键的推导步骤:
- 式(8) 的推导:
2.3、RoPE的编码形式
上面我们设了q,k的绝对位置编码函数为:
然后又求出了:
而:
那带入(4)式就可以得出q,k的绝对位置编码函数了(下面以q为例,k同理)
为避免这个正交矩阵过于稀疏,浪费算力,代码实现时都是依据下面公式来计算RoPE:
注:苏神在θ的选择上沿用了tansformer的θi = 10000-2i/d 。因为苏神实验发现,在RoPE中采用这个θ也可以带来一定的远程衰减性(意思就是token之间的依赖关系会随着距离的变远而衰减,这也符合我们的直观理解)。当然别的θ也可,只要满足远程衰减。
3、RoPE 代码实现(torch版)
- 代码实现基于torch,代码中也写好详细注释。如有错误,评论区或私信我反馈,谢谢~
import torch import torch.nn as nn import torch.nn.functional as F import math # %% def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device): # (max_len, 1) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1) # (output_dim//2) ids = torch.arange(0, output_dim // 2, dtype=torch.float) # 即公式里的i, i的范围是 [0,d/2] theta = torch.pow(10000, -2 * ids / output_dim) # (max_len, output_dim//2) embeddings = position * theta # 即公式里的:pos / (10000^(2i/d)) # (max_len, output_dim//2, 2) embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1) # (bs, head, max_len, output_dim//2, 2) embeddings = embeddings.repeat((batch_size, nums_head, *([1] * len(embeddings.shape)))) # 在bs维度重复,其他维度都是1不重复 # (bs, head, max_len, output_dim) # reshape后就是:偶数sin, 奇数cos了 embeddings = torch.reshape(embeddings, (batch_size, nums_head, max_len, output_dim)) embeddings = embeddings.to(device) return embeddings # %% def RoPE(q, k): # q,k: (bs, head, max_len, output_dim) batch_size = q.shape[0] nums_head = q.shape[1] max_len = q.shape[2] output_dim = q.shape[-1] # (bs, head, max_len, output_dim) pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device) # cos_pos,sin_pos: (bs, head, max_len, output_dim) # 看rope公式可知,相邻cos,sin之间是相同的,所以复制一遍。如(1,2,3)变成(1,1,2,2,3,3) cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1) # 将奇数列信息抽取出来也就是cos 拿出来并复制 sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1) # 将偶数列信息抽取出来也就是sin 拿出来并复制 # q,k: (bs, head, max_len, output_dim) q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1) q2 = q2.reshape(q.shape) # reshape后就是正负交替了 # 更新qw, *对应位置相乘 q = q * cos_pos + q2 * sin_pos k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1) k2 = k2.reshape(k.shape) # 更新kw, *对应位置相乘 k = k * cos_pos + k2 * sin_pos return q, k # %% def attention(q, k, v, mask=None, dropout=None, use_RoPE=True): # q.shape: (bs, head, seq_len, dk) # k.shape: (bs, head, seq_len, dk) # v.shape: (bs, head, seq_len, dk) if use_RoPE: q, k = RoPE(q, k) d_k = k.size()[-1] att_logits = torch.matmul(q, k.transpose(-2, -1)) # (bs, head, seq_len, seq_len) att_logits /= math.sqrt(d_k) if mask is not None: att_scores = att_logits.masked_fill(mask == 0, -1e-9) # mask掉为0的部分,设为负无穷大 att_scores = F.softmax(att_logits, dim=-1) # (bs, head, seq_len, seq_len) if dropout is not None: att_scores = dropout(att_scores) # (bs, head, seq_len, seq_len) * (bs, head, seq_len, dk) = (bs, head, seq_len, dk) return torch.matmul(att_scores, v), att_scores if __name__ == '__main__': # (bs, head, seq_len, dk) q = torch.randn((8, 12, 10, 32)) k = torch.randn((8, 12, 10, 32)) v = torch.randn((8, 12, 10, 32)) res, att_scores = attention(q, k, v, mask=None, dropout=None, use_RoPE=True) # (bs, head, seq_len, dk), (bs, head, seq_len, seq_len) print(res.shape, att_scores.shape)
?
?
?
?
?
?
Reference
- Transformer升级之路:2、博采众长的旋转式位置编码
- 《RoFormer: Enhanced Transformer with Rotary Position Embedding》
- RoPE详细推导版
- Transformer升级之路:6、旋转位置编码的完备性分析
- 让研究人员绞尽脑汁的Transformer位置编码
- Transformer升级之路:4、二维位置的旋转式位置编码