pytorch RNN&LSTM

最近几天把花书上RNN一节给看完了。然后到网上搜了一个LSTM的例子,依葫芦画瓢跑了一下,对RNN&LSTM应该说也是有了一个大致的了解,今天就来记录一下。

RNN介绍

RNN主要是被设计出来用来对序列数据进行建模的。
在这里插入图片描述
上图是经典的RNN的模型,对模型进行一个输入x,然后在隐藏单元同时有一个x输入和 h t 1 h_{t-1} 分别和权重相乘再激活函数作用下就得到了 h t h_t h t h_t 在权重相乘激活就得到了输出 y t y_t ,进而可以计算一下Loss function了。公式如下图所示:
在这里插入图片描述
这就是RNN原理的全部了,RNN在传播时是从序列的开始向前传播,一直传播到最后一个时刻,反向传播时就是从最后一个时刻一个时间步一个时间步向后传播回去。
当然了,RNN还存在很多变种的,比如导师驱动过程和输出循环网络等等,不过都大同小异,上面展示的这个RNN应当是最基础,也是最普遍的一个。

LSTM

LSTM是在RNN基础上进行改进的一个模型,主要是加入了三个门,输入门、输出门、遗忘门。输入门的作用是控制输入,也就是上面提到的x,实际上就是学习一个权重去和x相乘,控制在本个时刻我需要传进去多少x,遗忘门的作用就是控制上一个时刻的状态 h t 1 h_{t-1} 有多少我还需要保留着,需要继续沿着时间往后传播,其实也就是学习一个权重和 h t 1 h_{t-1} 相乘。输出门的作用是控制有多少输出需要传出去,其实也就是学习一个权重和输出相乘,然后output。

上面是一种通俗的解释,基本的道理是没错的,不过在一些细节方面还是有点不大对,因为LSTM和RNN的结构还是有一点小区别的。详细的LSTM单元结构可以看这个博文的解释。

LSTM进行手写数字识别

之前我写过CNN的模型进行手写数字识别,今天在网上看到也可以用LSTM进行手写数字识别,觉得有意思,顺带刚好练习一下用pytorch进行LSTM模型的搭建,参考文章是这篇。

Pytorch LSTM模型参数解释

在这里插入图片描述
关于这个例子中的模型的一些更详细的解释可以看我的notebook代码及说明,然后关于LSTM参数的解释可以看一下下面这两篇博文:
https://blog.csdn.net/yangyang_yangqi/article/details/84585998
https://zhuanlan.zhihu.com/p/39191116
看完这两篇博文和我写的notebook(最后自己跑一下)基本能够有一个大致的理解了。

发布了41 篇原创文章 · 获赞 40 · 访问量 5万+

猜你喜欢

转载自blog.csdn.net/qq_39805362/article/details/104206344
今日推荐