NLP 相关算法 LSTM 算法流程

LSTM希望通过改进的RNN内部计算方法来应对普通RNN经常面临的梯度消失和梯度爆炸。基本思路是通过改变逆向传播求导时单纯的偏导连乘关系,从而避免较小的sigmoid或relu激活函数偏导连乘现象。
RNN网络unfold以后,将按时间t展开为若干个结构相同的计算单元,每个计算单元在利用当前时间的输入以外,还需要之前时间的输出。以下将展示每个计算单元的内部计算流程,假设当前的计算单元对应时间为t。
每个计算单元内由input gateforget gateoutput gate三个“闸门”结构依先后顺序构成。在每一个gate内部,相关的输入都匹配专门的权重矩阵,各个输入相加后都将匹配专门的bias向量,总体求和后需要通过专门的激活函数进行处理形成输出。 设定当期(即t期)输入为 x t x_t ,前一期输出为 o t 1 o_{t-1}

input gate

input gate实际上是类似于一个filter,即用sigmoid激活函数的激活值过滤或加权实际的input。实际的input为:
i = t a n h ( x t W i x + o t 1 W i o + b i ) i=tanh(x_t W_{i}^x+o_{t-1} W_{i}^o+b_{i})
sigmoid激活函数filter为:
I G = s i g m o i d ( x t W I G x + o t 1 W I G o + b I G ) IG=sigmoid(x_t W_{IG}^x+o_{t-1} W_{IG}^o+b_{IG})
input gate层的最终输出就是 I I I G IG 的点乘,即元素层面的对应相乘。
I o u t = i I G I_{out}=i \circ IG

inner state s t s_t

LSTM较于普通RNN网络增加了一个内部状态量 s t s_t . 记忆的控制就是通过forget gate对于 s t 1 s_{t-1} 的过滤而发挥作用。

forget gate

与input gate相同,forget gate也是一个sigmoid激活函数激活值形成的filter,用于对上一期的状态量 s t 1 s_{t-1} 进行过滤。
F G = s i g m o i d ( x t W F G x + o t 1 W F G o + b F G ) FG=sigmoid(x_t W_{FG}^x+o_{t-1} W_{FG}^o+b_{FG})
当期的状态量 s t s_t 就是input gate层的输出值与IG过滤后的上一期状态量的简单相加的结果。注意这里的操作仅为简单的相加,并没有加入权重,不存在相乘,也没有使用新的激活函数,这一步骤是消除RNN反向传播网络梯度消失或梯度爆炸的关键:
s t = s t 1 F G + I o u t s_t=s_{t-1} \circ FG + I_{out}

output gate

同之前的两个gate类似,output gate也是一个sigmoid激活函数filter,对当期的状态量 s t s_t 进行过滤。 s t s_t 在接受过滤前,先使用tanh激活函数进行区间压缩:
O G = s i g m o i d ( x t W O G x + o t 1 W O G o + b O G ) OG=sigmoid(x_t W_{OG}^x+o_{t-1} W_{OG}^o+b_{OG})
以此对压缩后的 s t s_t 进行过滤,形成最终当期计算单元的最终输出:
o t = t a n h ( s t ) O G o_t=tanh(s_t) \circ OG
o t o_t s t s_t 将可用于下一期(t+1)计算单元的内部计算。

猜你喜欢

转载自blog.csdn.net/yuanjackson/article/details/83543831