# 循环神经网络（RNN）：RNNCell 的定义

import torch

# Demo: manually unroll a torch.nn.RNNCell over a short random sequence.
batch_size = 1
seq_len = 3  # number of time steps x_t in one input sequence
sep_len = seq_len  # backward-compatible alias for the original (misspelled) name
input_size = 4  # features per time step
hidden_size = 2  # size of the hidden state h_t

cell = torch.nn.RNNCell(input_size=input_size, hidden_size=hidden_size)

# Input laid out as (seq, batch, feature); iterating over dim 0 below yields
# one (batch, feature) slice per time step.
dataset = torch.randn(seq_len, batch_size, input_size)
hidden = torch.zeros(batch_size, hidden_size)  # initial hidden state h0

# Unroll the cell by hand: h1 = cell(x1, h0), h2 = cell(x2, h1), ...
# The loop variable is named ``x_t`` rather than ``input``: a for-loop
# variable outlives the loop at module level, so ``input`` would otherwise
# stay bound to a tensor and shadow the builtin ``input()`` afterwards.
for idx, x_t in enumerate(dataset):
    print('=' * 20, idx, '=' * 20)
    # NOTE: the labels say "_size" but the values printed are tensor shapes.
    print('input_size', x_t.shape)

    hidden = cell(x_t, hidden)

    print('output_size', hidden.shape)
    print(hidden)

# 猜你喜欢

# 转载自 blog.csdn.net/qq_21686871/article/details/114407865