BPTT of RNN

我在这一篇里推导一下RNN的反向传播算法,尝试从另一个角度来理解在RNN在时间轴上的依赖关系,由此导出所谓的梯度消失问题,即原始的RNN为什么难以训练。

首先定义RNN的前向过程如下:

网上常见的解释是,对于RNN,在$t$时刻的输入不仅可以直接影响当前时刻的输出$\mathbf{h}_t$,而且可以通过$\mathbf{h}_t$对下一时刻的输出$\mathbf{h}_{t + 1}$产生影响,因此计算梯度的时候,需将$t + 1$时刻的误差也传递到当前时刻。个人觉得这种解释没有真实的反映出时间序列上的依赖关系,因为$t$时刻输入会对后续时刻的整个序列的输出产生影响,不仅仅是下一个时刻的输出,因此,从这里理解,$t$时刻的误差应该是后续时间步上误差传递到当前时刻值的总和,也正是由于这种在时间轴上从后向前的传递,导致了在时间步(time step)变大时,靠后时刻的误差不能传递到序列的早期时刻,引发了所谓的梯度消失问题(长期依赖问题)。下面对这段文字描述做数学推导。

首先定义序列时长为$T$(即time_step = $T$)。假设这是一个多层的many-to-many的RNN,对于其中某一层,输入和输出表示为:

在每个时刻,可以获得由上一层网络的传递得到误差项(实际为梯度,这里为了方便),表示为

定义$\delta_k$:

那么$\delta_k$就表示在当前输入序列下,$k \leqslant t \leqslant T$时刻范围内对$\mathbf{z}_k$的梯度之和,这也是我们最终要求的真实梯度,即在当前输入序列下,可以计算出的误差对$\mathbf{z}_k$的梯度。

再看一下$\delta_k^t$:表示当前输出的误差对$\mathbf{z}_k$的梯度,在feed forward的网络结构中,这里$\delta_k^t = \delta_k$。

根据$\delta_k^t$的性质:

结合$\delta_k^t​$和$\delta_k​$,可以得到$\delta_k​$的表达式:

从上面的这个式子可知,如果要计算时刻$k$的梯度,那么需要两层循环,内循环处理连乘,外循环处理加和,显然,实际的bptt实现并没有这个复杂,因为$\delta_k$可以用更简洁的形式表达。下面导出这种具体形式。

首先$\delta_k^t$和$\delta_{k + 1}^t$的关系为:

根据$\delta_{k + 1}$的表达式,对其右乘$\frac{\partial \mathbf{z}_{k + 1}}{\partial \mathbf{z}_k}$

因此得到了递推式:

上式就是RNN的BPTT在实现时的逻辑,从这里可以更加容易看出时间依赖的本质,注意$\delta_{k + 1} \ne \delta_{k + 1}^{k + 1}$,$\delta_k$表示的是在时刻$t \in [k + 1, T]$,对时刻$\mathbf{z}_{k + 1}$梯度的累积量,而$\delta_{k + 1}^{k + 1}$仅仅表示$\ell_{k + 1}对$$\mathbf{z}_{k + 1}$的梯度。

因此本质上不能理解成$t$时刻输入$\mathbf{x}_t$仅仅会对$\mathbf{h}_t, \mathbf{h}_{t + 1}$产生影响,就此在反向传播的时候将两个时刻产生的梯度相加,因为时间依赖的本质是可以无限延伸的,即沿着时间轴从$T \to 1$传播,这也就是所谓TT:through time的含义。

这里还要提一下RNN的几种常见结构,many-to-one, one-to-one, one-to-many, 和many-to-many,我上面的分析是针对最后一种即从上层网络传递下来的误差向量维度和输入维度保持一致。前面三种现在理解起来也十分容易了。many-to-one就是要将最后时刻(一般是这样)的梯度传递到输入的每个时刻,one-to-many就是将每个时刻的梯度传递到起始时刻,one-to-one只需要将最后时刻传递到起始时刻就行了,即$\delta_1^T$。

再以one-to-one分析一下梯度消失问题,根据上面的解释,可以写出$\delta_1^T$:

其中

当时间步$T$过大时,矩阵和激活值的连乘很容易导致$\delta_1^T$过大或者过小。过大的解决方法比较简单,目前都采用梯度裁剪的方法(clip gradient)。过小就是所谓的梯度消失问题,RNN无法从训练中获得远距离梯度的更新,也就无法学习到远距离的依赖关系,解决的方法可以是重新设计循环结构,合理初始化权值以及使用合适的激活函数等等。