# -*- encoding: utf-8 -*-
'''
@Author: Xiaosu Wang
@Email: [email protected]
@Version : 1.0
@File : rnn.py
@Time : 2020-01-30 22:44
@Description :
Pytorch中RNN相关源码在文件:torch/nn/modules/rnn.py
'''
import torch
import torch.nn as nn
# Fix the random seed so that the randomly initialised RNN/RNNCell weights and
# the random inputs created below are the same on every run.
init_seed = 2020
for _seed_rng in (torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(init_seed)  # CPU generator first, then the CUDA generator
del _seed_rng
# NumPy's RNG is not used by this script; if NumPy randomness is ever added,
# also call np.random.seed(init_seed) here.
def print_parameters(module):
    """Print every parameter of *module*, name first and then its value.

    Iterates ``module.named_parameters()`` and prints each name on its own
    line, followed by the printed form of the parameter tensor.

    Args:
        module: the ``nn.Module`` whose parameters are printed.
    """
    for name, param in module.named_parameters():
        print(name)
        print(param)
# Constructor arguments of nn.RNN (implementation: torch/nn/modules/rnn.py):
#   input_size:    number of features in each input step (e.g. the word-vector size)
#   hidden_size:   number of features in the hidden state
#   num_layers:    number of stacked recurrent layers. Default: 1
#   nonlinearity:  'tanh' or 'relu'. Default: 'tanh'
#   bias:          whether the layer uses the bias terms b_ih and b_hh. Default: True
#   batch_first:   if True, input and output tensors are (batch, seq, feature);
#                  the default (False) means (seq_len, batch, input_size)
#   dropout:       if non-zero, applies dropout to the outputs of every layer
#                  except the last, so it only matters when num_layers > 1. Default: 0
#   bidirectional: if True, becomes a bidirectional RNN. Default: False
rnn = nn.RNN(4, 5, 1)  # nn.RNN(input_size, hidden_size, num_layers)
rnn_cell = nn.RNNCell(4, 5)  # nn.RNNCell(input_size, hidden_size)
# NOTE: `input` shadows the builtin of the same name; it is kept so the
# module-level variable names of this script stay unchanged.
input = torch.randn(3, 2, 4)  # (seq_len, batch, input_size)
h0 = torch.randn(1, 2, 5)  # (num_layers * num_directions, batch, hidden_size)

# Run the whole sequence through nn.RNN in a single call.
output, hn = rnn(input, h0)

# Emulate nn.RNN by stepping nn.RNNCell along the sequence by hand.
# Iterating a tensor walks its first (seq_len) dimension, so this covers every
# time step instead of relying on a hard-coded count.
hx = h0[0]  # one layer, one direction -> the only initial hidden state
output_cell = []
for step_input in input:
    hx = rnn_cell(step_input, hx)
    output_cell.append(hx)
print('----' * 5 + '两者结果不同' + '----' * 5)
print(output)
print(output_cell)
# rnn_cell was initialised independently of rnn, so the results should differ.
# Compare numerically rather than by eye; stacking the per-step states yields
# the same (seq_len, batch, hidden_size) layout as `output`. A small atol
# absorbs float round-off between the fused and the step-by-step computation.
print('allclose:', torch.allclose(output, torch.stack(output_cell), atol=1e-6))

print('----' * 5 + '观察 RNN 、RNNCell 的参数' + '----' * 5)
print_parameters(rnn)
print('----' * 5)
print_parameters(rnn_cell)

print('----' * 5 + '将 RNN 的参数赋值给 RNNCell,使两者 Cell 的参数一样' + '----' * 5)
# Plain assignment makes rnn_cell refer to the very same Parameter objects as
# rnn's layer-0 weights: they are shared from now on, not copied.
rnn_cell.weight_ih = rnn.weight_ih_l0
rnn_cell.weight_hh = rnn.weight_hh_l0
rnn_cell.bias_ih = rnn.bias_ih_l0
rnn_cell.bias_hh = rnn.bias_hh_l0
print('----' * 5 + '观察 RNN 、RNNCell 的参数' + '----' * 5)
print_parameters(rnn)
print('----' * 5)
print_parameters(rnn_cell)

print('----' * 5 + '重新用 RNNCell 模拟 RNN' + '----' * 5)
# Repeat the manual unroll, now with identical weights on both sides.
output_cell = []
hx = h0[0]
for step_input in input:
    hx = rnn_cell(step_input, hx)
    output_cell.append(hx)
print('----' * 5 + '两者结果相同' + '----' * 5)
print(output)
print(output_cell)
print('allclose:', torch.allclose(output, torch.stack(output_cell), atol=1e-6))
# hn holds the hidden state after the final step, so it must match the last hx.
print('hn matches last hx:', torch.allclose(hn[0], hx, atol=1e-6))
print('----' * 5 + '多层、双向都是相同的道理' + '----' * 5)
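# Tags: RNN、RNNCell
# [web-page residue] 猜你喜欢 (You may also like)
# Reposted from (转载自): blog.csdn.net/wangxiaosu/article/details/104120393
# [web-page residue] 今日推荐 (Today's picks)
# [web-page residue] 周排行 (Weekly ranking)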