LSTM输入层要求的维度是三维的,其中包含三个参数: [batch_size, input_dim, time_step]
1.训练的话一般一批一批训练,即让batch_size 个样本同时训练;
2.每个样本又包含从该样本往后的连续seq_len个样本(如seq_len=15),seq_len也就是LSTM中cell的个数;
3.每个样本又包含inpute_dim个维度的特征(如input_dim=7)
因此,输入层的输入数据通常先要reshape:
x= np.reshape(x, (batch_size , seq_len, input_dim)) #[64, 28, 28]
(友情提示:每个cell共享参数!)
扫描二维码关注公众号,回复:
10693382 查看本文章
LSTM识别MNIST手写数字集
# RNN学习时使用的参数
learning_rate = 0.001
training_iters = 100000
batch_size = 128
display_step = 10
# 神经网络的参数
n_input = 28 # 输入层的n
n_steps = 28 # 28长度
n_hidden = 128 # 隐含层的特征数
n_classes = 10 # 输出的数量,因为是分类问题,0~9个数字,这里一共有10个