rnn(penn tree bank)
rnn的一个典型应用是用于处理自然语言,ptb是一个常用的数据集,里面包含上百万的词汇。本例就是采用lstm对ptb进行自然语言预测。
一些参数
batch_size = 20
num_steps = 20 # lstm在时间轴上的截断长度
hidden_size = 200 # lstm隐藏参数的大小
num_layers = 3 # lstm网络的深度
- 1
- 2
- 3
- 4
- 5
(1)数据预处理过程
原始数据是编码后的文本,用一维向量表示;将该原始数据reshape成batch size宽度的数据,以提高数据处理效率;在长度方向每隔num steps长度截断一次,构成网络输入x;将x右移一个位置构成标签y’。(这里的y’的每个元素都刚好为x的同位置元素的下一个词汇,因为rnn模型主要用于词汇的预测)
(2)网络流程
input: x(20,20), target: y’(20, 20) → (400)
1) after embedding: input → (20, 20, 200)
2) lstm output: 20 x (20, 200) → (400, 200)
3) softmax: output → (400, 10000)
4) loss between target and output
5) evaluate: perplexity
注意细节
每处理一个batch进行一次参数更新,每个batch的截断长度为20;
rnn处理完一个batch后保存每一层最后一个时刻的状态,作为下个batch状态的初值