[PyTorch] rnn,lstm,gru中输入输出维度 - 简书
LSTM神经网络输入输出究竟是怎样的? - 知乎
pytorch文档
可以把上面的每一列看成是有厚度的(如 128 维).
为什么输入可以是任意长度?
以 RNN 为例, 使用的网络是同一个网络, 只是在不同的时间步接收输入而已.
即 RNN 在一个时间步只接收一个输入 . 所以我们在下面 pytorch 设置 LSTM 网络参数的时候不用设置 time_step, 只需要设置输入向量的维度就可以.
注意:不要被展开图给迷惑了.
设置网络参数
torch.nn.LSTM( input_size, hidden_size, num_layers )
输入特征的维度 ‘num_units’
接收输入
Inputs: input, (h_0, c_0)
‘三维’‘三维’‘三维’
- input of shape (seq_len, batch, input_size) — batch 指 一个 batch 所含的序列个数, LSTM 一次处理一个 batch 的所有序列
- h_0 of shape (num_layers * num_directions, batch, hidden_size)
- c_0 of shape (num_layers * num_directions, batch, hidden_size)
输出
Outputs: output, (h_n, c_n)
- output of shape (seq_len, batch, num_directions * hidden_size)
- h_n of shape (num_layers * num_directions, batch, hidden_size)
- c_n of shape (num_layers * num_directions, batch, hidden_size)
Example
>>> rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2) # 它才不管你的输入长度, 任意长度都可以
>>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20)
>>> c0 = torch.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> print(output.size())
torch.Size([5, 3, 20])
表示输出 3 个 batch, 每个 batch 形状为 [5, 20]