output_list = [] # 存放输出的list
for idlayer in range(self.num_layers): # 每一次循环
current_input = input[idlayer,...]
current_output,current_hidden_state = self.cell(current_input,current_hidden_state)
# 按照输出间隔将输出存储下来
if idlayer % self.out_stride == self.out_stride - 1:
output_list.append(current_output)
output = torch.stack(output_list, dim=0) # 将输出合并成一个tensor
先用list.append存储下来
再用torch.stack接受list,并定义需要拼接的新维度即可。
current_output的纬度为[10,4,64,64]
n个current_output拼接的output的纬度为[n,10,4,64,64]