torch.cat是PyTorch深度学习框架中的一个函数,用于将多个张量沿着指定的维度拼接在一起。具体来说,它可以将多个形状相同的张量按照指定的维度进行拼接,返回一个新的张量。
torch.cat的语法如下:
torch.cat(seq, dim=0, *, out=None) -> Tensor
其中,参数seq是要拼接的张量序列,它们应该具有相同的形状(除了沿着拼接维度的大小),可以是一个Python列表或元组。参数dim表示要沿着哪个维度进行拼接,默认值为0(即第0个维度)。返回值是一个新的张量,与输入张量在拼接维度上的大小之和相同。
例如,假设有两个形状为(3, 4)的张量x1和x2,我们可以使用以下代码将它们沿着第0个维度进行拼接:
import torch
x1 = torch.randn(3, 4)
x2 = torch.randn(3, 4)
y = torch.cat([x1, x2], dim=0)
print(y.shape) # 输出:torch.Size([6, 4])
上述代码将输出一个形状为(6, 4)的新张量y
,其中前三行是x1
的内容,后三行是x2
的内容。