Fastformer: Additive Attention Can Be All You Need

前言

本文要是对《Fastformer: Additive Attention Can Be All You Need》这篇论文的一个解读与总结,原文链接[2108.09084] Fastformer: Additive Attention Can Be All You Need (arxiv.org)               

本文提出一种新型Transformer模型,它被设计用来降低transformer的时间复杂度,同时优化transformer处理长文本的性能效率。Fastformer采用了加性注意力机制来对全局上下文进行建模,然后根据其与全局上下文表示的交互进一步转换每个token表示。它不仅实现了具有线性复杂性的有效上下文建模,而且经测试它比现已有的许多Transformer模型的效率都更高,同时可以实现更好的长文本建模能力。

现已有研究存在的问题

点积自注意力机制的限制:由于self-attention计算每对位置的输入表示之间的点积,因此其时间复杂度是输入序列长度的二次方,这导致普通的Transformer模型很难有效的处理长文本输入序列。

本文中例举了很多Transformer模型:                                                                                                       ① BigBird:计算稀疏注意力代替密集注意力。它使用局部注意力、特定位置的全局注意力和特定数量token之间的随机注意力的组合。然而,稀疏注意力通常不能完全模拟全局上下文。                 ② Linformer:通过计算近似矩阵,利用了自注意矩阵的低秩特性。它将注意力关键字和值投影到与序列长度无关的低维矩阵中。然而,这种近似实际上是上下文不可知的,这可能会削弱Transformer的上下文建模能力。                                                                                                               ③ Longformer:结合了滑动窗口注意力和全局注意力来模拟局部和全局上下文。然而它的计算复杂度仍然很高,同时参数量较大,性能依赖全局注意力。                                                                 ④ Linear Transformer:使用核函数近似自注意机制的线性复杂度变换器。然而它很依赖核函数的选择,同时会受核函数的限制,训练过程中可能出现不稳定的问题。                                               ⑤ Poolingformer:它首先使用滑动窗口自我注意力来捕捉短距离上下文,然后使用集中自我注意力来捕获长距离上下文。然而它可能会导致信息丢失,因为池化会丢失部分信息,对序列顺序敏感,因为此话会破快序列顺序信息,同时无法提供足够的上下文信息,因为池化的局限。                   此外,当输入序列长度很长时,这些方法都不够有效。

加性注意力机制                                                                              

首先使用三个独立的线性变换层将输入转换成注意力查询,键,值矩阵  Q,K,Vin mathbb{R}_{}^{N	imes d}将它们分别写成 Q= left [ q_{1}, q_{2}cdots , q_{N} 
ight ]K= left [ k_{1}, k_{2}cdots , k_{N} 
ight ]V= left [ v_{1},v_{2}cdots , v_{N} 
ight ]         

  (1)我们首先使用加性注意将查询矩阵概括为全局查询向量q。                                                                 ① 先计算查询矩阵Q每一列的注意力权重 a_{i} :                                                                                                                              a_{i}= frac{exp left (w _{q}^{T}q_{i/sqrt{d}} 
ight )}{sum_{j= 1}^{N}exp left (w _{q}^{T}q_{j/sqrt{d}} 
ight )}                                                                       ② 将对应权重 a_{i} 与列向量 q_{i} 相乘求和,得到全局查询向量 q:                                                                                                   q= sum_{i=1}^{N}a_{i}q_{i}   

(2)我们通过元素乘积将全局查询向量与每个键向量相结合,以学习全局上下文感知键矩阵,并通过加性注意力将其进一步总结为全局键向量 k。                                                                                       ① 通过将全局查询向量 q 与键矩阵 K 的每一列逐位相乘:                                                                                                           p_{i}= qast k_{i}                                                                                           ② 计算全局上下文感知键矩阵每一列的注意力权重 eta _{i} :                                                                                                                eta _{i}= frac{exp left (w _{k}^{T}p_{i/sqrt{d}} 
ight )}{sum_{j= 1}^{N}exp left (w _{k}^{T}p_{j/sqrt{d}} 
ight )}                                                                      ③ 将对应权重 eta _{i} 与列向量 p_{i} 相乘求和,得到全局键向量 k :                                                                                                       k= sum_{i=1}^{N}eta _{i}p_{i} 

(3)最后,我们对全局上下文感知键和值之间的交互进行建模,以学习全局上下文感知注意力值,并将其与查询进一步组合以形成最终输出。                                                                                          ① 通过全局键向量 k 和值矩阵 V 的每一个列向量逐位相乘得到键-值交互向量 u_{i} :                                                                        u_{i}= kast v_{i}                                                                                        ② 将线性变换层用于每个键-值交互向量 u_{i} ,以学习其隐藏表示得到 r_{i} ,该层输出矩阵表示为 R= left [ r_{1} ,r_{2},cdots ,r_{N}
ight ] in mathbb{R}^{N	imes d} :                                                                                                                                                 r_{i}= Transformationleft ( u_{i} 
ight )                                                             ③  最后将学习到的隐藏表示 r_i  与查询矩阵的每一列 q_{i} 相加得到最终结果 Output :                                                                 Output= q+r 

复杂度与参数量分析

(1)时间复杂度                                                                                                                                         ① Fastformer的时间复杂度为: Oleft ( Ncdot d 
ight )            

        ② Transformer的时间复杂度为: Oleft ( N^{2}cdot d
ight )

(2)总参数量                                                                                                                                            ① Fastformer的空间复杂度为: 3hd^{2}+2hd  。其中 3hd^{2} 表示有 h 个关注头,每个关注头中的查询、键和值转换所需要的参数,就是 W_{q}^{T},W_{k}^{T},W_{v}^{T} ,它们的大小为:d^{2},    将 Input 转换成 Q,K,V就是 3d^{2}  ,h 个关注头就是 3hd^{2}。其中 2hd 表示全局查询 q 和全局键 k 的参数量为 2d , h 个关注头就是 2hd ,所以总的就是 3hd^{2}+2hd 。

        ② Transformer的总参数量至少为: 4hd^{2} ,准确来说的总参数量为: 3hd^{2}+hNd,因为一般 N> d 所以总参数量至少为 4hd^{2} 。其中 3hd^{2} 与上面一样,是 W_{q}^{T},W_{k}^{T},W_{v}^{T} 的参数量。其中 hNd 表示  Softmaxleft ( QK^{T} 
ight )  所得到的注意力权重矩阵,最后与值矩阵相乘。

最后作者在多个任务以及多个数据集上验证了Fastformer的效率确实优于其他Tranformer模型,同时时间复杂度也比其他的Transformer模型更低。

作者还研究了参数共享可以提高模型性能的同时减小模型参数大小,经对比发现Query-Value Sharing+Head-wise Sharing效果最好,不同层之间参数共享可以减轻过拟合的风险,同时Head-wise Sharing可能会导致降低模型效果,因为不同注意力头希望捕获到不同的上下文模式,使用同样的注意力头参数可能会导致反效果。

以上就是我对这篇文章的理解,如有不当欢迎指正!!!