Reference: https://blog.csdn.net/m0_45478865/article/details/104455978
nn.LSTM:
# nn.LSTM(input_size, hidden_size, num_layers, ...)
# NOTE: the second argument is the hidden-state size, not an "output_dim" —
# the original comment mislabeled it.
lstm = nn.LSTM(768, 768, 1, bidirectional=True, batch_first=True)
# input tensor: (batch_size, seq_len, input_size)
x = torch.randn(2, 512, 768)
# calling the LSTM returns a 2-tuple: (output, (h_n, c_n))
result = lstm(x)
len(result) # 2 : (output, (hn, cn))
# output: hidden states h_t of the last layer for every time step
# h_n: hidden state at the final time step (t = seq_len), per layer/direction
# c_n: cell state at the final time step (t = seq_len), per layer/direction
output, (hn, cn) = result
output.shape # torch.Size([2, 512, 1536]) (batch, seq_len, num_directions * hidden_size)
hn.shape # torch.Size([2, 2, 768]) (num_layers * num_directions, batch, hidden_size)
cn.shape # torch.Size([2, 2, 768]) (num_layers * num_directions, batch, hidden_size)
nn.GRU:
# nn.GRU(input_size, hidden_size, num_layers, ...)
gru = nn.GRU(input_size=768, hidden_size=768, num_layers=1, batch_first=True)
# input tensor: (batch_size, seq_len, input_size)
x = torch.randn(2, 512, 768)
# a GRU call yields a 2-tuple (output, h_n) -- no cell state, unlike LSTM
result = gru(x)
len(result) # 2
# output: per-time-step hidden states of the top layer
# hn: hidden state at the final time step (t = seq_len)
output, hn = result
output.shape # torch.Size([2, 512, 768]) (batch, seq_len, num_directions * hidden_size)
hn.shape # torch.Size([1, 2, 768]) (num_layers * num_directions, batch, hidden_size)