LSTM希望通过改进的RNN内部计算方法来应对普通RNN经常面临的梯度消失和梯度爆炸。基本思路是通过改变逆向传播求导时单纯的偏导连乘关系,从而避免较小的sigmoid或relu激活函数偏导连乘现象。
RNN网络unfold以后,将按时间t展开为若干个结构相同的计算单元,每个计算单元在利用当前时间的输入以外,还需要之前时间的输出。以下将展示每个计算单元的内部计算流程,假设当前的计算单元对应时间为t。
每个计算单元内由input gate,forget gate和output gate三个“闸门”结构依先后顺序构成。在每一个gate内部,相关的输入都匹配专门的权重矩阵,各个输入相加后都将匹配专门的bias向量,总体求和后需要通过专门的激活函数进行处理形成输出。 设定当期(即t期)输入为
,前一期输出为
。
input gate
input gate实际上是类似于一个filter,即用sigmoid激活函数的激活值过滤或加权实际的input。实际的input为:
sigmoid激活函数filter为:
input gate层的最终输出就是
与
的点乘,即元素层面的对应相乘。
inner state
LSTM较于普通RNN网络增加了一个内部状态量 . 记忆的控制就是通过forget gate对于 的过滤而发挥作用。
forget gate
与input gate相同,forget gate也是一个sigmoid激活函数激活值形成的filter,用于对上一期的状态量
进行过滤。
当期的状态量
就是input gate层的输出值与IG过滤后的上一期状态量的简单相加的结果。注意这里的操作仅为简单的相加,并没有加入权重,不存在相乘,也没有使用新的激活函数,这一步骤是消除RNN反向传播网络梯度消失或梯度爆炸的关键:
output gate
同之前的两个gate类似,output gate也是一个sigmoid激活函数filter,对当期的状态量
进行过滤。
在接受过滤前,先使用tanh激活函数进行区间压缩:
以此对压缩后的
进行过滤,形成最终当期计算单元的最终输出:
和
将可用于下一期(t+1)计算单元的内部计算。