概述
通过前一节对循环神经网络RNN的了解,简单的RNN虽然能够解决长期依赖问题,但是训练和优化比较困难,然后长短时记忆模型LSTM很大程度上解决长期依赖问题,本文主要介绍
1.LSTM的提出
2.LSTM网络结构
3.LSTM的分析
LSTM的提出
早在94年Hochreiter发现了RNN训练过程中的梯度消失和爆炸问题,然后在99年提出LSTM解决该问题。
梯度消失问题的原因可以参考之前的介绍。
常量错误传播
RNN难训练的主要原因在后向传播过程中,梯度随着时间序列的增加而逐渐消失。如果误差能够不消减的进行传递,则可以避免训练难得问题。
常量错误传播-直观想法
假设隐藏层只有一个节点j,则该节点误差计算过程为
如果想做到常误差传播,则需要
此时可以近似无限长时间序列,但是网络过于简单并且实现比较复杂。
常量错误木马(Constant Error Carousel-CEC)
LSTM也是根据CEC演化而来。
针对上面的必要条件
此时激活函数必须是线性并且激活值保持为常量。
作者将满足这个等式的传播称之为 CEC。
当然如果满足CEC的约束条件能够进行常误差传播,但是该网络结构过于简单,同时有两个相关问题需要解决:
1.输入权重冲突:隐藏层节点不仅要保存历史信息,还要对当前的输入进行响应。
2.输出权重冲突:同理隐藏层节点不仅保存历史信息,还要响应该节点到输出层的反馈。
因此隐藏层不仅保存历史信息还要对输入和输出进行不同程度的响应,会导致相关的权重更新冲突。这也是LSTM加入门节点的重要原因之一。
LSTM原型
为了解决1)常量误差传播2)冲突问题,作者通过引入内存单元和门单元重新构建了网络结构。
输入门:能够保护内存单元的内容不受不相关输入的影响,最大程度保存有用信息。
输出门:能够保护其他单元不受不相关内存的影响。网络拓扑如下:
1.y_in表示输入门 ,y_out表示输出门。门在网络结构中相当于新引入隐藏节点,不过激活函数一般选取为sigmoid,将权重压缩到0-1之间
2.内存单元状态sc 用于存储历史信息和当前输入。
3.内存块,这也是该模型长短时记忆由来的原因。在一个内存块中可以有多个内存单元,他们有着相同的输入和输出门。引入内存块能够减少门的个数,加快计算效率。
引入遗忘门
为了解决状态不断累加问题,引入遗忘门对状态进行缩减。网络结构如下:
引入Peephole Connection
为了学习到更精确的时序,引入Peephole Connection,使得各种门的状态能够学习更准确。这也是目前使用最多的网络结构。
LSTM经典网络结构
目前使用较多的网络结构如下图所示:
1.上图展示的是一个内存块并且内存块中仅有一个内存单元
2.网络结构中包括输入、输出和遗忘门;Cell为内存状态单元;虚线表示Peephole连接。
LSTM展开图
下图更形象的展示了LSTM的网络结构,隐藏层有两个内存块,每个内存块有2个单元。
LSTM数学表达
符号定义
I 表示输入集合
H 表示隐藏层节点个数
C为内存块中细胞单元个数
K表示输出层节点个数
γ,ϕ,ω分布表示输入门、遗忘门以及输出门
S表示内存单元状态
前向遍历过程
某个内存块计算过程如下:
- 输入门表达
atγbtγ=∑i=1Iwi,γxti+∑h=1Hwh,γbt−1h+∑c=1Cwc,γst−1c=f(atγ) - 遗忘门表达
atϕbtϕ=∑i=1Iwi,ϕxti+∑h=1Hwh,ϕbt−1h+∑c=1Cwc,ϕst−1c=f(atϕ) - 细胞单元表达
atcstc=∑i=1Iwi,cxti+∑h=1Hwh,cbt−1h=btϕst−1c+btγg(atc) - 输出门表达
atωbtω=∑i=1Iwi,ωxti+∑h=1Hwh,ωbt−1h+∑c=1Cwc,ωstc=f(atω) - 内存单元输出
btc=btωh(stc)
后向遍历过程
后向遍历过程,类似于RNN,需要计算得到各个节点的误差,然后传递到相关的权重。
符号定义
ϵtc=∂L∂btc 表示细胞单元输出的误差,即到达btc 时的汇总误差
ϵts=∂L∂stc 表示中间状态的误差
δtk 表示输出层节点的误差
1.细胞节点误差
2.输出门误差
3.细胞状态误差
4.细胞节点的误差
5.遗忘门误差
6.输入门误差
LSTM分析
优势
1.在一个内存块中LSTM能够做到常误差传递(通过后向传递过程可以得到);以及在不同时序阶段中通过输入门和遗忘门的控制能够使得误差极大化向后传递;最终使得LSTM能够解决较长的依赖问题
2.LSTM能够处理噪声、连续型输入以及分布式表示等问题
3.LSTM泛化能力较其他模型强
4.不需要细粒度参数调优,而且计算复杂度和RNN一致
局限
1.计算复杂度线性增加
2.RNN加上相关优化算法能够解决100以内依赖问题,LSTM能够处理1000以内依赖问题;对于更长的依赖LSTM也会遇到梯度相关问题。
总结
通过本文学习能够非常清楚LSTM的网络结构以及求解算法、理解LSTM内存块内的常量误差传递。