使用相对位置编码下MHSA的计算

最近几年语音的各项任务上也开始用上transformer的模型了,前段时间大家在谈论conformer的时候,我看了一下Transformer-XL里面提出的相对位置编码表示方法。原论文我在开始看的时候感觉符号层面有些混乱,向量和矩阵容易有些分不清,因此自己在后续实现的时候花了一些功夫思考和整理。本篇文章用于记录一下我自己对使用相对位置编码时,self-attention计算的相关分析和理解。

首先做一下符号化,本文统一用大写表示矩阵,小写表示向量,为了简便起见均没有加粗,向量默认遵从列向量规范。query,key,value对应的投影(变换)矩阵定义为$W^Q, W^K, W^V$,通过线性运算(矩阵乘法)得到参与dot attention计算的$Q, K, V$。在self-attention(也就是encoder里面的一般用法)里面,query,key和value值是相同的,这里用$X$统一表示,即:

其中$X\in \mathbb{R}^{T \times D_m}$,$W^Q, W^K, W^V \in \mathbb{R}^{D_m \times D_k}$,$T$表示序列长度。带入dot attention的计算公式:

得到

$(2)$式中,$QK^T \in \mathbb{R}^{T \times T}$,$C \in \mathbb{R}^{T \times D_m}$,用$C$表示此项结果是含有attention context的意思。上面为了符号简便,没有考虑multi-head的情况,不过这个并不影响下文的分析和推理。以前speech里面会把取softmax之后结果称之为alignment或者attention weight,相对位置编码的改动主要也是集中在此。现在单独拿出

考虑。使用绝对位置编码时,一般的做法是用学习到的或者固定的位置编码与原始embedding$X$相叠加作为输入,令绝对位置编码矩阵为$P \in \mathbb{R}^{T \times D_m}$。先看一下Transformer-XL中的分析。它是基于用$X = X + P$带入$(4)$式的结果进行分析:

这里个人觉得单做分析使用问题不大,实际使用transformer时候,考虑到layernorm和深层网络等因素,理论上并不完全如此(不严格等价),不过不影响本文对使用相对位置编码下计算的分析过程。令$W = W^Q (W^K)^T$, 对$(5)$中的$S$展开有:

相对和绝对位置编码的区别在于前者只对相对位置进行考虑,符号化来看,对绝对位置编码矩阵$P$,分解为:

其中$p_i \in \mathbb{R}^{D_m \times 1}$为列向量。对于相对位置编码矩阵$R$,在帧长为$T$的序列下,一共存在$2T - 1$个相对位置,若不考虑边界约束,则$R \in \mathbb{R}^{2T-1 \times D_m}$:

其中$r_i \in \mathbb{R}^{D_m \times 1}$同样为列向量。考虑边界约束,比如最大半径为$Z$,则$R \in \mathbb{R}^{2Z - 1 \times D_m}$,本文考虑前者进行分析。现在来看,我们是不能直接将$(6)$式中的$P$替换为$R$进行计算的,我把第二项拿出来单独考虑,令$M_p = XWP^T \in \mathbb{R}^{T \times T} $,展开:

当使用相对位置编码时,我们希望的$M_r \in \mathbb{R}^{T \times T}$计算结果应该是

如果直接替换$P$为$R$,得到的矩阵$\hat{M_r} = XWR^T \in \mathbb{R}^{T \times 2T - 1} $为:

对比$(10)$和$(11)$式中的$M_r$和$\hat{M_r}$不难发现二者之间的关系,即:

因此从计算的角度来说,我们可以先拿到$\hat{M_r}$,再调整得到$M_r$,这也就是Transformer-XL附录B想要说明的东西。不过$(12)$的结论是在不考虑相对位置编码的边界约束下得到的,如果相对半径$Z < T$,则:

$(10)$式中$M_r \in \mathbb{R}^{T \times T}$改写为

$(11)$式中$\hat{M_r} = XWR^T \in \mathbb{R}^{T \times 2Z - 1} $改写为

此时二者的关系为:

其中$b_i = \max(0, i - Z + 1)$,其余位置的值为边界值依次重复扩展。在这种情况下,一种比较方便的方法是生成相对位置编码的时候按照$R \in \mathbb{R}^{2T-1 \times D_m}$的大小进行生成,即:

这样之后的计算结果结论便和无边界约束的$(12)$保持一致。

现在再和”Self-Attention with Relative Position Representations”一文中的做法做一下联系。它对$(4)$式的拆解为:

即只对key部分叠加了相对位置信息。其中对第二项$XW^QR^T$的计算同样可以按照$(17)$和$(12)$的结论进行。由此便将两种使用相对位置编码的self-attention在计算方法和逻辑上得到了统一。