使用相对位置编码下MHSA的计算

最近几年语音的各项任务上也开始用上transformer的模型了,前段时间大家在谈论conformer的时候,我看了一下Transformer-XL里面提出的相对位置编码表示方法。原论文我在开始看的时候感觉符号层面有些混乱,向量和矩阵容易有些分不清,因此自己在后续实现的时候花了一些功夫思考和整理。本篇文章用于记录一下我自己对使用相对位置编码时,self-attention计算的相关分析和理解。

首先做一下符号化,本文统一用大写表示矩阵,小写表示向量,为了简便起见均没有加粗,向量默认遵从列向量规范。query,key,value对应的投影(变换)矩阵定义为WQ,WK,WV,通过线性运算(矩阵乘法)得到参与dot attention计算的Q,K,V。在self-attention(也就是encoder里面的一般用法)里面,query,key和value值是相同的,这里用X统一表示,即:

(1)Q=XWQ,K=XWK,V=XWV

其中XRT×Dm,WQ,WK,WVRDm×DkT表示序列长度。带入dot attention的计算公式:

(2)C=softmax(QKTDk)V

得到

(3)C=softmax(XWQ(XWK)TDk)VWV

(2)式中,QKTRT×TCRT×Dm,用C表示此项结果是含有attention context的意思。上面为了符号简便,没有考虑multi-head的情况,不过这个并不影响下文的分析和推理。以前speech里面会把取softmax之后结果称之为alignment或者attention weight,相对位置编码的改动主要也是集中在此。现在单独拿出

(4)S=XWQ(XWK)T=XWQ(WK)TXT

考虑。使用绝对位置编码时,一般的做法是用学习到的或者固定的位置编码与原始embeddingX相叠加作为输入,令绝对位置编码矩阵为PRT×Dm。先看一下Transformer-XL中的分析。它是基于用X=X+P带入(4)式的结果进行分析:

(5)S=(X+P)WQ(WK)T(X+P)T

这里个人觉得单做分析使用问题不大,实际使用transformer时候,考虑到layernorm和深层网络等因素,理论上并不完全如此(不严格等价),不过不影响本文对使用相对位置编码下计算的分析过程。令W=WQ(WK)T, 对(5)中的S展开有:

(6)S=XWXT+XWPT+PWXT+PWPT

相对和绝对位置编码的区别在于前者只对相对位置进行考虑,符号化来看,对绝对位置编码矩阵P,分解为:

(7)P=[p0Tp1TpT1T]

其中piRDm×1为列向量。对于相对位置编码矩阵R,在帧长为T的序列下,一共存在2T1个相对位置,若不考虑边界约束,则RR2T1×Dm

(8)R=[r0Tr1Tr2T1T]or[rT+1TrT+2TrT1T]

其中riRDm×1同样为列向量。考虑边界约束,比如最大半径为Z,则RR2Z1×Dm,本文考虑前者进行分析。现在来看,我们是不能直接将(6)式中的P替换为R进行计算的,我把第二项拿出来单独考虑,令Mp=XWPTRT×T,展开:

(9)Mp=[x0Tx1TxT1T]W[p0p1pT1]=[x0TWp0x0TWp1x0TWpT1x1TWp0x1TWp1x1TWpT1xT1TWp0xT1TWp1xT1TWpT1]

当使用相对位置编码时,我们希望的MrRT×T计算结果应该是

(10)Mr=[x0TWr0x0TWr1x0TWr2x0TWrT1x1TWr1x1TWr0x1TWr1x1TWrT2x2TWr2x1TWr1x1TWr0x1TWrT3xT1TWrT+1xT1TWrT+2xT1TWrT+3xT1TWr0]

如果直接替换PR,得到的矩阵Mr^=XWRTRT×2T1为:

(11)Mr^=[x0TWrT+1x0TWr0x0TWr1x0TWrT1x1TWrT+1x1TWr0x1TWr1x1TWrT1x2TWrT+1x2TWr0x2TWr1x2TWrT1xT1TWrT+1xT1TWr0xT1TWr1xT1TWrT1]

对比(10)(11)式中的MrMr^不难发现二者之间的关系,即:

(12)Mr[i]=Mr^[i,T1i:2T1i]

因此从计算的角度来说,我们可以先拿到Mr^,再调整得到Mr,这也就是Transformer-XL附录B想要说明的东西。不过(12)的结论是在不考虑相对位置编码的边界约束下得到的,如果相对半径Z<T,则:

(13)R=[r0Tr1Tr2Z1T]or[rZ+1TrZ+2TrZ1T]

(10)式中MrRT×T改写为

(14)Mr=[x0TWr0x0TWr1x0TWr2x0TWrZ2x0TWrZ1x0TWrZ1x1TWr1x1TWr0x1TWr1x1TWrZ1x1TWrZ1x1TWrZ1xZ1TWrZ+1xZ1TWrZ+2xZ1TWrZ+3xZ1TWrZ1xZ1TWrZ1xZ1TWrZ1xT1TWrZ+1xT1TWrZ+1xT1TWrZ+1xT1TWrZ1xT1TWrZ1xT1TWrZ1]

(11)式中Mr^=XWRTRT×2Z1改写为

(15)Mr^=[x0TWrZ+1x0TWr0x0TWr1x0TWrZ1x1TWrZ+1x1TWr0x1TWr1x1TWrZ1x2TWrZ+1x2TWr0x2TWr1x2TWrZ1xT1TWrZ+1xT1TWr0xT1TWr1xT1TWrZ1]

此时二者的关系为:

(16)Mr[i,bi:bi+Z]=Mr^[i,Z1i:2Z1i]

其中bi=max(0,iZ+1),其余位置的值为边界值依次重复扩展。在这种情况下,一种比较方便的方法是生成相对位置编码的时候按照RR2T1×Dm的大小进行生成,即:

(17)R=[r0Tr0Tr1Tr2Z2Tr2T1T]or[rZ+1TrZ+1TrZ+2TrZ1Tr2T1T]

这样之后的计算结果结论便和无边界约束的(12)保持一致。

现在再和”Self-Attention with Relative Position Representations”一文中的做法做一下联系。它对(4)式的拆解为:

(18)S=XWQ(XWK+R)T=QKT+XWQRT

即只对key部分叠加了相对位置信息。其中对第二项XWQRT的计算同样可以按照(17)(12)的结论进行。由此便将两种使用相对位置编码的self-attention在计算方法和逻辑上得到了统一。