Play with Beam Search

前面写了一篇文章介绍了一下自己一直在维护的APS项目,完稿的时候我还没有公开代码,这是因为当时发现,我在一些public的数据集上训练的ASR模型和paper里已有的结果存在比较大的gap。ASR部分的代码写的时间是比较早的,不过那时都是在实验室的一些内部数据上做训练和测试,一直没有在Librispeech,WSJ这些集子上看看结果,所以此前并没有意识在这个问题。为了找到根源所在,缩小差距,从去年年底开始,我陆续花了一些功夫一一核对了问题可能存在的环节,包括特征,模型,解码,评估等等。虽然最后也没有找到什么明显的bug。但是在这个过程中我把自己一度重点怀疑的环节:解码的逻辑翻来覆去check了好几遍,这篇文章就算做个小总结吧。

正文开始前说几句无关主题的话。由于现在很多做asr的同学都在espnet上做实验,所以带来的一个十分尴尬的问题就是,我遇到的现象常常显得非常小众而他人难以依据他们自己过往的认识和理解(基于espnet)对我的现象和结论给出一些分析。 在这种情况下,我就只能依靠自己去double check自己的逻辑问题,这个显然也不算容易。我常用的方法是,对于某个算法比较确定的认识,我强迫自己用两种不同却等价的方式实现,只要核对结果相同就加以确认;对于不确定的认识,我会写出所有我觉得可能的实现,挑选最好的结果对应的认知和实现方法。不过这都是没有办法的办法了,遇到问题有人能参与积极的讨论还是很有必要的一件事情。

下面进入主题,APS的github地址在aps,感兴趣的同学可以尝试用一下,同时也欢迎review代码,提PR或者issue。

Acceleration

首先声明一下,本文中提到的beam search算法主要针对ASR任务中,基于attention机制的encoder-decoder结构的声学模型的解码,和NLP任务中MT的解码逻辑基本一致,decoder存在自回归逻辑,也称为left-right beam search。在CTC和RNNT的解码中,另有prefix beam search的概念,区别在于过程中我们是否考虑合并相同前缀的路径。CTC和RNNT的建模时blank节点的存在导致不同的前缀序列实际可能对应的是一个相同的解码序列, 因此严谨的做法是在beam search过程中把这些拥有相同解码序列的的beam节点合并掉,否则解码序列会产生冗余(即nbest里面存在相同的序列) 。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
gather_hypos = []
active_nodes = [init_node()]
# step by step
for u in range(U):
extention_nodes = []
# extend B nodes for each active node
for node in active_nodes:
extention_nodes += extend_topk(node)
# select top-B nodes for BxB extention_nodes
active_nodes = select_topk(extention_nodes)
for node in active_nodes:
# get decoding sequence from eos node
if is_eos_node(node):
gather_hypos.append(trace_back(node))

beam search的逻辑比较简单,假设beam size为$B$,字典大小为$V$,满足$B \leqslant V$,我简单写了一个伪代码如上所示。对于某个step $u$,我们首先对当前活跃的$B$个节点做拓展产生$B \times B$个后续节点,具体做法是:先拿到每个节点的$V$个概率值(总计$B \times V$个),再对保留每个节点的top $B$得到$B \times B$个后续节点。随后对此处$B \times B$个节点取全局的top $B$个节点作为$u + 1$步的活跃节点。$u = 0$的时候采用<sos>节点初始化活跃节点。beam search进行的过程中,当某个活跃节点上产生<eos>符号的时候,从这个节点回溯得到的序列便形成一条解码序列(可称为hypothesis)。这个过程会自顶向下的维护一个多叉树(叉数为$B$),而最直接的写法也十分类似树的层序遍历,在step $u$,写个for循环为每个活跃(父亲)节点拓展,进行$B$次计算(拿到decoder的概率输出进行的相关计算)。这种情况下,处理$U$步一共需要进行$U \times B$次计算。这种的逻辑实现起来十分直观,不容易出错(子节点的计算所需信息可以直接从父节点拿到),但是缺点也很明显,计算次数较多,体验上就是耗时久,速度慢。所以可以看一下如何去优化这一过程,提升实验效率。

回顾上述算法可以发现,在每个时刻内为活跃节点拓展子节点的操作不存在依赖性,也就是可以并行的,因此可以把这一个时刻内的逻辑:拓展$B \times B$节点和选取top$B$个节点用一次计算完成(增加一个batch维度),这样总体的计算次数被缩小了$B$倍,即$U \times B \to U$。依据这个逻辑,我们甚至可以改写代码[1],对多个句子同时解码。句子数为$N$时,在时刻$t$将拓展$N \times B \times B$个节点的计算一次完成。这种写法需要注意的是如何维护好父子节点之间的index关系,用于整理形成正确的解码序列。我自己的话额外的维护了一个point变量(具体可以参考源代码),用于指示当前$B$维的活跃节点对应的父节点下标。如果解码的时候考虑LM fusion和CTC score,我们依旧可以根据上面的逻辑并行LM和CTC score的计算,详细的代码可以参考aps的实现,我会在下文单独分析的时候再提一下。

继续优化的话,可以考虑$U$这个维度,即进行多少个step的beam search。标准情况下,beam search终止的条件是:当step $u$拓展节点取出的top $B$个节点全部是<eos>,即在当前$B$个活跃节点均拓展出了<eos>且它们的概率最大。一般来说我自己的观测,中文上,这种情况比较常见,英文上则比较困难,且随着beam size $B$的调大,越来越难以满足。因此默认情况下beam search还是有可能跑满预设的step最大值$U$。如何自动的work out出$U$这个值,我们可以看一下如下两种做法。一是设置一个关于eos解码序列(本文也称为hypothesis)总数的阈值$H$,如果产生的hypothesis个数超过了这个值,则自动停止。其二是espnet中采用的[2],如果连续$L$个step里,每个step产生的hypothesis对应的概率值和当前最优的hypothesis的概率值之差超过一个阈值$D$,则自动停止。这两种做法均是基于一个观测,即随着解码步数的增多,产生的hypothesis的概率倾向于越来越小。注意这里只是倾向,不是绝对,因此这些做法实际是有风险丢掉最优解的。做法1的假设是,最优的解码序列最有可能包含在最先的产生的$H$个hypothesis之中。做法2则有点类似WFST解码器的一种减枝方法,当连续$L$步产生的最优hypothesis的概率都足够小于当前已经拿到的最优值时,那么后续也不太可能产生好于当前最优hypothesis的结果。这两种做法基本上都可以自动的决定出$U$的大小,获得不错的加速体验。

解决了“快”这个问题之后,我们需要进一步的向“好”的目标进发,比较常见的做法是beam search的过程中加入LM的分数,一种代表性的做法被称为shallow fusion,这种情况下解码产生的最优序列从全局的角度来说,被很多paper里面表示为下面的式子:

$\lambda$表示语言模型的权重,$\gamma$表示正则项的权重。正则项(很多地方也叫惩罚项)的提出多是解决模型本身,或者加上了LM的分数之后,解码的结果存在一些bias这一现象,下面分两部分讨论。

Shallow Fusion

从实现的角度而言,shallow fusion改动很小。原版beam search在step $u$产生$B \times V$节点的时候,概率值全部用的是decoder产生的声学概率,即

而shallow fusion中使用

替代即可。由于decoder是存在自回归逻辑的,所以$(2)$式中的条件概率同时基于先前的解码序列$\mathbf{y}_{0, \cdots, u-1}$和全部的声学特征$\mathbf{x}$(通过attention机制产生的context传递),因此可以认为$(2)$式的$p_u$并不是纯粹的声学概率,而是融合了一部分LM的信息的,后续有一些文章的idea便是从decoder产生的声学概率中减去一部分LM分数,再和外部LM进行fusion。另外,我也发现了另一种做法的fusion,即$B \times V \to B \times B$节点的过程依旧采用声学得分,在后续筛选$B \times B \to B$的时候再fusion上LM的分数。由于第一阶段的选择没有使用LM,因此过滤掉了一些可能的解码序列,但是也由于AM的约束,一般也不会产生一些十分奇怪的结果。

$(3)$式中唯一的超参是LM权重$\lambda$,这个值我在早期的实验下得到的结论是最优值常在0.2上下浮动,而espnet中的动辄高于0.5的权重结果显然让我感到比较怪异。整体上说,LM产生的后验在方差上一般是比AM要低的,因此对相同的token,LM产生的log概率往往比AM要低不少, 这一点可以从采用相同的字典训练的LM相比AM在开发集上拥有更高的loss值(对应的预测的token准确率更低)看出。所以如果调大LM的权重,beam search的结果倾向于LM主导,解码的结果必定会更加发散,这就是为什么我认为LM的最优权重理论上应该维持在一个比较低的范围内。

首先我需要排除的是LM的问题。关于LM的训练我对比过两种训练方式,BPTT和整句训练。前者在训练中会将前一个batch产生的hidden state传递给后一个batch,每个batch的time step固定。这样的好处是训练阶段显存消耗稳定,但是dataloader设计上需要满足前后batch存在上下文关系,否则传递hidden state没有意义。整句训练是将若干句子包装成minibatch进行训练,由于LM的训练语料是变长的,因此长句子下容易出现OOM问题,一般采用动态的batchsize。在RNN结构下,我在Librispeech上发现BPTT训练在fusion中的效果略好。对比之下,采用整句训练的Transformer LM可以得到比BPTT训练的RNNLM好的结果。不过这两种方法产生的LM并没有在LM weight的规律上产生歧义。

紧接着我开始尝试一些正则方法,包括做score normalization,使用eos threshold [3,4,5],使用length penalty [6],coverage penalty [7,8] 等等,但是这些方法主要解决的是因为解码序列过短或者过长导致的插入/删除错误过多的现象。排除上述问题的可能性之后,我把关注点投向espnet自身的解码逻辑上来,并最终认为LM权重的问题根源极大可能在于解码过程中使用的CTC forward score上。当时我的解码逻辑并没有支持CTC score,因此迫不得已在春节前后花了两周多的时间添加了这部分逻辑进行对比。我的理解是,由于fusion的时候同时考虑了CTC score项,那么原来的$(3)$式需要写为:

把CTC score和原decoder产生的声学分数的加权结果视为新的声学分数(即前两项之和),如果相比原$\log p_{\text{att}}(\mathbf{y}_u|\mathbf{y}_{0, \cdots, u-1})$ ,新的声学分数增加了,那么要想维持原先0.2左右的权重,新的权重也就需要适当的调高,因此达到0.5-0.6的范围内是可能的,具体的取决于CTC score值本身的范围和权重$\gamma$的值。最终在debug完支持CTC score的beam search之后,我也拿到了关于$\gamma, \lambda$诸如(0.2, 0.4),(0.5, 0.5)和(0.4, 0.6)的配置,算是落实了之前的猜想。关于解码过程中CTC score的计算和作用我放在本篇文章的第三部分。

LM fusion容易产生的一个问题是解码序列存在length bias,倾向于产生较短的句子。在第一节中,我提到随着解码步数的增多,产生的hypothesis的概率倾向于越来越小(log分数为负,越加越小),如果不加任何正则,也就越难被选成one-best。可以设想$U = 3$时产生的hypothesis和$U = 30$产生的hypothesis,后者比前者多累加了27个step的AM和LM分数。在单纯依靠AM的情况下,句子较早结束的情况比较少见,因为在alignment未靠近句子结尾的时候,<eos>分数一般较低(反之更容易出现停不下来,即<eos>分数预测过低),不会被选择到下一个step中去,也就不会产生短的hypothesis。fusion上LM score之后,如果LM过早的给出一个较高的<eos>分数,则会对应的产生一个较短的hypothesis,导致length bias的出现。

本节结尾介绍两种不同于正则方法的操作,即score normalization和调整eos threshold。前者指最终对产生的所有hypothesis进行排序的时候,将分数除以序列长度,平均下来使得短句子不再具有优势;后者指的是限制<eos>预测的产生,具体的做法是,设置一个阈值$\phi$,当且仅当<eos>的概率高于非<eos>节点中最大概率的$\phi$倍时才允许被选取到下一个beam search step中去。第一种做法我实际尝试中发现,即使用于无LM fusion的情况下,也是可以提升解码one-best的WER(不大,但是有提升),也就是说,单纯依赖AM的解码结果也是存在一定程度的length bias。

Regularization

最后一节说一下常用的三种正则方法,分别是length penalty,coverage penalty和CTC penalty。我这里把espnet里面用的joint CTC & attention解码方法也认为是一种正则方式,下文会讨论这样看待的原因。关于正则方法,从实现的角度看,我们在使用的时候需要注意的是,每一个beam search step进行还是最后做hypothesis重新排序的时候进行一次即可。

最简单的length penalty作用和第二部分结尾提到的score normalization类似,为了解决fusion之后对短句子的length bias问题,对应的$(1)$式改写为:

由于在每个beam search step内维护的hypothesis长度是相同的,所以只需要在最后进行重排序的时候加上长度惩罚项即可。长句子的惩罚项较大,故而可以纠正短句子上的length bias。MT里面采用的算法不太一样,比如OpenNMT的实现选用的是:$(5 + |\mathbf{y}|)^{\alpha} / (5 + 1)^{\alpha}$,不过初衷都是类似的。

coverage penalty也是在MT和ASR中都被使用的正则方法,对应的$(1)$式改写为:

其中$\mathbf{A}_{\mathbf{y}} \in \R^{T \times U}$表示解码序列$\mathbf{y}$对应的attention权重矩阵。对于coverage惩罚项,我认为可以在每个beam search step都计算一次用于$B \times B \to B$的筛选,因为当前维护的partial hypothesis对应着不同的attention矩阵。但是从实现的角度来说,只在最后一步计算用于重排序更加简单,虽然两者是不等效的。Google的文章[7,8]中可以看到两种计算方式(省掉了下标):

由于attention权重矩阵$\mathbf{A}$是在时间维度过了softmax,因此对于$u \in [0, U - 1]$,满足约束$\sum_t \mathbf{A}_{t, u} \equiv 1$ 。整体看来,coverage计算的是attention权重向量超出某个阈值的总帧数,正常的解码结果权重矩阵是单调的,且在时间轴上的投影的长度会尽可能覆盖掉的活跃的speech的区间,因此满足条件的总帧数多,coverage值较大,而短的解码序列由于结束的过早,满足条件的帧数小。这么看来,coverage penalty实际上描述了一个解码序列对应的alignment的质量的高低,从这个点看,就可以和espnet里面使用的CTC forword score联系起来了。

不同于coverage penalty基于attention权重矩阵去衡量解码序列的对齐质量高低,espnet使用encoder的CTC分支产生的概率计算出的forward score来描述。对于过短或者过长(比如存在重复识别)的句子,CTC层面的对齐质量很差,得分很低使得这些序列对应的节点在beam search的过程中很快会被淘汰。使用forward score时,$(1)$式写为:

实现的时候,需要在每个step计算出拓展当前节点贡献的CTC得分(注意时当前节点不是全部序列,因为CTC的分数需要和decoder的AM分数,LM分数在每个step累加),即$(4)$式中的$p_{\text{ctc}}(\mathbf{y}_u|\mathbf{y}_{0, \cdots, u-1}, \mathbf{x})$,计算方法符号化如下:

forward score的计算在时间轴上存在一个循环(动态规划的逻辑),所以该方法虽然鲁棒性较高,但是计算复杂度也很大(参看[2])。

深入一些,使用CTC score做LM fusion的时候,fusion的逻辑采用的时我在第二部分第一段介绍的后一种方法。用语言描述就是,先使用AM的分数拓展出$B \times V \to B \times B_{\text{ctc}}$个节点($B_\text{ctc} = \alpha B, \alpha \geqslant 1$),然后在这$B \times B_{\text{ctc}}$节点上fusion CTC的分数和LM的分数,最后选取出top $B$个几点到下一个step,即$B \times B_{\text{ctc}} \to B$。优化上依旧可以采用第一部分的逻辑,将$B \times B_{\text{ctc}}$个节点的CTC分数计算放在一个batch里面得到,这部分想了解细节的同学可以参看一下aps的源码。当然也可以沿用原始fusion的逻辑,但是计算量增大很多(取决于字典$V$的大小):需要为$B \times V$个节点计算CTC的分数。解码加上CTC score在一般的数据集上都会带来一些提升(aishell1和librispeech上都是0.几个点),但是不算明显(数据比较标准,集外数据上的提升作用可能会放大),从这点上看也符合正则项带来的效果。

在分析WER的时候,除了看总体的数值之外,还应该参考一下对应的删除,插入,替换错误。一般而言,替换错误的占比很高,而删除和插入错误的分布大致相等,如果统计下来,出现了明显的高插入错误或者高删除错误,对应的往往是句子过短和句子过长。句子过短可以采用上述的五种方法进行调整(本节的三个正则加上第二节的score normalization和eos threshold),句子过长一般是出现了<eos>预测不出导致的序列循环,停不下来的现象。对于这一问题,上面的方法都很难完美的解决,往往是声学模型本身出了一些问题,可以从调整网络结构的角度尝试办法。

Conclusion

本篇文章分三个部分写了一些关于beam search的内容,主要基于自己在过去的一段时间内对解码进行优化和调试的一些经验和感悟。后续我会找时间在标准数据集上整理一些比较理想的结果,可以让读者直观的体会一下LM fusion和正则项对识别结果的提升作用。

Reference

[1]. Seki H, Hori T, Watanabe S, et al. Vectorized Beam Search for CTC-Attention-Based Speech Recognition[C]//INTERSPEECH. 2019: 3825-3829.
[2]. Watanabe S, Hori T, Kim S, et al. Hybrid CTC/attention architecture for end-to-end speech recognition[J]. IEEE Journal of Selected Topics in Signal Processing, 2017, 11(8): 1240-1253.
[3]. Zeyer A, Bahar P, Irie K, et al. A comparison of transformer and lstm encoder decoder models for asr[C]//2019 IEEE Automatic Speech Recognition and Understanding Workshop (ASRU). IEEE, 2019: 8-15.
[4]. Irie K, Prabhavalkar R, Kannan A, et al. On the choice of modeling unit for sequence-to-sequence speech recognition[J]. arXiv preprint arXiv:1902.01955, 2019.
[5]. Hannun A, Lee A, Xu Q, et al. Sequence-to-sequence speech recognition with time-depth separable convolutions[J]. arXiv preprint arXiv:1904.02619, 2019.
[6]. Bahdanau D, Chorowski J, Serdyuk D, et al. End-to-end attention-based large vocabulary speech recognition[C]//2016 IEEE international conference on acoustics, speech and signal processing (ICASSP). IEEE, 2016: 4945-4949.
[7]. Chorowski J, Jaitly N. Towards better decoding and language model integration in sequence to sequence models[J]. arXiv preprint arXiv:1612.02695, 2016.
[8]. Kannan A, Wu Y, Nguyen P, et al. An analysis of incorporating an external language model into a sequence-to-sequence model[C]//2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2018: 1-5828.