Rotary Position Embedding (RoPE, 旋转式位置编码) | 原理讲解+torch代码实现

  • ?? RoPE为苏剑林大佬之作,最早应用于他自研的RoFormer (Rotary Transformer),属于相对位置编码。效果优于绝对位置编码和经典式相对位置编码。出自论文:《RoFormer: Enhanced Transformer with Rotary Position Embedding》

  • ?? 据我了解,最近发布的大语言模型:Meta的LLaMA、清华的ChatGLM都采用了RoPE。这也足以证明了RoPE的优势。

  • ?? 本文讲解下个人对RoPE原理的理解以及自己用torch复现了一下,更详细地请参阅苏神的原文(文末已附上链接)。

  • ?? 如对RoPE公式推导有任何疑问,可评论区或私信反馈,我将做出详细解答。

文章目录

  • 1、RoPE 动机
    • 1.1、绝对位置编码
    • 1.2、相对位置编码
    • 1.3、RoPE
  • 2、RoPE 原理
    • 2.1、将待解问题公式化(提出假设)
    • 2.2、推导求解
    • 2.3、RoPE的编码形式
  • 3、RoPE 代码实现(torch版)
  • Reference

1、RoPE 动机

1.1、绝对位置编码

  • 最原始的正余弦位置编码(即sinusoidal位置编码)是一种绝对位置编码,但从其原理中的正余弦的和差化积公式来看,引入的其实也是相对位置编码。

  • 绝对位置编码的讲解可看我的博客:随记·手撕coding | absolute positional embedding

  • 优势: 实现简单,可预先计算好,不用参与训练,速度快。

  • 劣势: 没有外推性,即如果预训练最大长度为512的话,那么最多就只能处理长度为512的句子,再长就处理不了了。当然,也可以将超过512的位置向量随机初始化,然后继续微调。

1.2、相对位置编码

  • 经典相对位置编码RPR式的讲解可看我的博客:相对位置编码之RPR式:《Self-Attention with Relative Position Representations》论文笔记 【在k, v中注入相对位置信息】
  • 优势: 直接地体现了相对位置信号,效果更好。具有外推性,处理长文本能力更强。

1.3、RoPE

  • RoPE通过绝对位置编码的方式实现相对位置编码,综合了绝对位置编码和相对位置编码的优点。
  • 主要就是对attention中的q, k向量注入了绝对位置信息,然后用更新的q,k向量做attention中的内积就会引入相对位置信息了。

2、RoPE 原理

? 那rope是怎么在q,k中注入这种相对位置信息的呢?我看了苏神的推导。大概是这样的:先假设q,k是二维的情形,因为复数可用二维向量表示,所以借助复数域来求解。在推导的过程中,用的最多的一句话就是:“为简单起见,假设xxx” 这对推导十分关键。

  • 有关复数相关基础知识可看这:数学 | 复数的代数、向量、矩阵、极坐标、指数形式 | 复数相乘的物理意义【旋转+缩放】

2.1、将待解问题公式化(提出假设)

首先,假设新的qk向量(即假设已注入绝对位置信息)的内积会引入相对位置信息。并在最后假设合理的初始化条件:
在这里插入图片描述

2.2、推导求解

不是一般性,考虑其q,k向量为二维的情形,借助复数域推导出为q,k向量编码绝对位置信息的函数 f 。
在这里插入图片描述

别看公式多,理解起来并不难。下面我细说一下其中几个关键的推导步骤:

  • 式(8) 的推导:
    在这里插入图片描述

2.3、RoPE的编码形式

上面我们设了q,k的绝对位置编码函数为:
在这里插入图片描述
然后又求出了:在这里插入图片描述
而:
在这里插入图片描述
那带入(4)式就可以得出q,k的绝对位置编码函数了(下面以q为例,k同理)
在这里插入图片描述
为避免这个正交矩阵过于稀疏,浪费算力,代码实现时都是依据下面公式来计算RoPE:
在这里插入图片描述
注:苏神在θ的选择上沿用了tansformer的θi = 10000-2i/d 。因为苏神实验发现,在RoPE中采用这个θ也可以带来一定的远程衰减性(意思就是token之间的依赖关系会随着距离的变远而衰减,这也符合我们的直观理解)。当然别的θ也可,只要满足远程衰减。

3、RoPE 代码实现(torch版)

  • 代码实现基于torch,代码中也写好详细注释。如有错误,评论区或私信我反馈,谢谢~
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# %%

def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device):
    # (max_len, 1)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1)
    # (output_dim//2)
    ids = torch.arange(0, output_dim // 2, dtype=torch.float)  # 即公式里的i, i的范围是 [0,d/2]
    theta = torch.pow(10000, -2 * ids / output_dim)

    # (max_len, output_dim//2)
    embeddings = position * theta  # 即公式里的:pos / (10000^(2i/d))

    # (max_len, output_dim//2, 2)
    embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)

    # (bs, head, max_len, output_dim//2, 2)
    embeddings = embeddings.repeat((batch_size, nums_head, *([1] * len(embeddings.shape))))  # 在bs维度重复,其他维度都是1不重复

    # (bs, head, max_len, output_dim)
    # reshape后就是:偶数sin, 奇数cos了
    embeddings = torch.reshape(embeddings, (batch_size, nums_head, max_len, output_dim))
    embeddings = embeddings.to(device)
    return embeddings


# %%

def RoPE(q, k):
    # q,k: (bs, head, max_len, output_dim)
    batch_size = q.shape[0]
    nums_head = q.shape[1]
    max_len = q.shape[2]
    output_dim = q.shape[-1]

    # (bs, head, max_len, output_dim)
    pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device)


    # cos_pos,sin_pos: (bs, head, max_len, output_dim)
    # 看rope公式可知,相邻cos,sin之间是相同的,所以复制一遍。如(1,2,3)变成(1,1,2,2,3,3)
    cos_pos = pos_emb[...,  1::2].repeat_interleave(2, dim=-1)  # 将奇数列信息抽取出来也就是cos 拿出来并复制
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)  # 将偶数列信息抽取出来也就是sin 拿出来并复制

    # q,k: (bs, head, max_len, output_dim)
    q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
    q2 = q2.reshape(q.shape)  # reshape后就是正负交替了



    # 更新qw, *对应位置相乘
    q = q * cos_pos + q2 * sin_pos

    k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1)
    k2 = k2.reshape(k.shape)
    # 更新kw, *对应位置相乘
    k = k * cos_pos + k2 * sin_pos

    return q, k


# %%

def attention(q, k, v, mask=None, dropout=None, use_RoPE=True):
    # q.shape: (bs, head, seq_len, dk)
    # k.shape: (bs, head, seq_len, dk)
    # v.shape: (bs, head, seq_len, dk)

    if use_RoPE:
        q, k = RoPE(q, k)

    d_k = k.size()[-1]

    att_logits = torch.matmul(q, k.transpose(-2, -1))  # (bs, head, seq_len, seq_len)
    att_logits /= math.sqrt(d_k)

    if mask is not None:
        att_scores = att_logits.masked_fill(mask == 0, -1e-9)  # mask掉为0的部分,设为负无穷大

    att_scores = F.softmax(att_logits, dim=-1)  # (bs, head, seq_len, seq_len)

    if dropout is not None:
        att_scores = dropout(att_scores)

    # (bs, head, seq_len, seq_len) * (bs, head, seq_len, dk) = (bs, head, seq_len, dk)
    return torch.matmul(att_scores, v), att_scores


if __name__ == '__main__':
    # (bs, head, seq_len, dk)
    q = torch.randn((8, 12, 10, 32))
    k = torch.randn((8, 12, 10, 32))
    v = torch.randn((8, 12, 10, 32))

    res, att_scores = attention(q, k, v, mask=None, dropout=None, use_RoPE=True)


    # (bs, head, seq_len, dk),  (bs, head, seq_len, seq_len)
    print(res.shape, att_scores.shape)


?
?
?
?
?
?

Reference

  • Transformer升级之路:2、博采众长的旋转式位置编码
  • 《RoFormer: Enhanced Transformer with Rotary Position Embedding》
  • RoPE详细推导版
  • Transformer升级之路:6、旋转位置编码的完备性分析
  • 让研究人员绞尽脑汁的Transformer位置编码
  • Transformer升级之路:4、二维位置的旋转式位置编码