torch.cat(seq,dim,out=None)
其中seq表示要连接的两个序列,以元组的形式给出,例如:seq=(a,b), a,b 为两个可以连接的序列
dim 表示以哪个维度连接,dim=0, 横向连接,dim=1,纵向连接
举例如下:
import torch
a = torch.ones([1, 2])
b = torch.ones([1, 2])
print(torch.cat([a, b], 1)) # dim=1纵向连接
print(torch.cat([a, b], 0)) # dim=0横向连接
输出结果:
tensor([[1., 1., 1., 1.]])
tensor([[1., 1.],
[1., 1.]])
纵向连接之后,维度变成1*4
横向连接之后,维度变成2*2