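torch.cat and torch.stack combine several tensors into one, and torch.split breaks one tensor into several.
The two joining functions differ slightly: torch.cat concatenates along an existing dimension and adds no dimension, while torch.stack stacks along a new dimension, so the result has one more dimension than the inputs.
Examples: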
cat
torch.cat(tensors, dim=0, *, out=None) → Tensor
import torch

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
print("A shape: {}".format(A.shape))
B = torch.tensor([[12, 22, 33],
                  [44, 55, 66],
                  [77, 88, 99]])
print("B shape: {}".format(B.shape))
A shape: torch.Size([3, 3])
B shape: torch.Size([3, 3])
dim=0
Concatenating along dim 0 adds no new dimension: cat of two (3, 3) tensors gives shape (6, 3).
result1 = torch.cat((A, B), 0)
print("result1 shape: {}".format(result1.shape))
print(result1)
result1 shape: torch.Size([6, 3])
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [12, 22, 33],
        [44, 55, 66],
        [77, 88, 99]])
dim=1
Concatenating along dim 1 adds no new dimension: cat of two (3, 3) tensors gives shape (3, 6).
result2 = torch.cat((A, B), 1)
print("result2 shape: {}".format(result2.shape))
print(result2)
result2 shape: torch.Size([3, 6])
tensor([[ 1,  2,  3, 12, 22, 33],
        [ 4,  5,  6, 44, 55, 66],
        [ 7,  8,  9, 77, 88, 99]])
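Both cat examples above use inputs of identical shape. More generally, torch.cat only requires the shapes to agree on every dimension except the one being concatenated. A minimal sketch of that rule follows; the tensors C and D are hypothetical and are not part of the original example.

```python
import torch

# C and D are hypothetical tensors used only for this illustration.
C = torch.zeros(3, 3)
D = torch.ones(2, 3)

# The shapes differ only along dim 0, so concatenating along dim 0 works.
rows = torch.cat((C, D), dim=0)
print(rows.shape)  # torch.Size([5, 3])

# Along dim 1 the row counts (3 vs 2) do not match, so cat raises an error.
try:
    torch.cat((C, D), dim=1)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("cat failed:", err)
```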
stack
dim=0
Stacking at dim 0 inserts a new dimension at index 0: stack of two (3, 3) tensors gives shape (2, 3, 3).
result3 = torch.stack((A, B), dim=0)
print("result3 shape: {}".format(result3.shape))
print(result3)
result3 shape: torch.Size([2, 3, 3])
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[12, 22, 33],
         [44, 55, 66],
         [77, 88, 99]]])
dim=1
Stacking at dim 1 inserts a new dimension at index 1: stack of two (3, 3) tensors gives shape (3, 2, 3).
result4 = torch.stack((A, B), dim=1)
print("result4 shape: {}".format(result4.shape))
print(result4)
result4 shape: torch.Size([3, 2, 3])
tensor([[[ 1,  2,  3],
         [12, 22, 33]],

        [[ 4,  5,  6],
         [44, 55, 66]],

        [[ 7,  8,  9],
         [77, 88, 99]]])
dim=2
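Stacking at dim 2 is valid even though the inputs are 2-D: the result has three dimensions, so index 2 is in range. stack of two (3, 3) tensors gives shape (3, 3, 2).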
result5 = torch.stack((A, B), dim=2)
print("result5 shape: {}".format(result5.shape))
print(result5)
result5 shape: torch.Size([3, 3, 2])
tensor([[[ 1, 12],
         [ 2, 22],
         [ 3, 33]],

        [[ 4, 44],
         [ 5, 55],
         [ 6, 66]],

        [[ 7, 77],
         [ 8, 88],
         [ 9, 99]]])
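One way to read the three stack results above is that stacking at dim gives each input a size-1 axis at that position and then concatenates along it. The sketch below checks this equivalence for the same A and B; the loop and the names stacked, via_cat and checks are illustrative additions, not part of the original example.

```python
import torch

A = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
B = torch.tensor([[12, 22, 33], [44, 55, 66], [77, 88, 99]])

checks = []
for dim in range(3):
    stacked = torch.stack((A, B), dim=dim)
    # Insert a size-1 axis at `dim` in each input, then concatenate there.
    via_cat = torch.cat((A.unsqueeze(dim), B.unsqueeze(dim)), dim=dim)
    checks.append(torch.equal(stacked, via_cat))
    print(dim, tuple(stacked.shape), checks[-1])
```

This also shows why dim=2 is accepted for 2-D inputs: unsqueeze(2) is a legal position for the new axis.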
split
q, k, v = torch.split(result3, split_size_or_sections=1, dim=1)
print(q, q.shape)
print(k, k.shape)
print(v, v.shape)
print()
print(torch.stack((q, k, v), 1), torch.stack((q, k, v), 1).shape)
print(torch.cat((q, k, v), 1), torch.cat((q, k, v), 1).shape)
tensor([[[ 1,  2,  3]],

        [[12, 22, 33]]]) torch.Size([2, 1, 3])
tensor([[[ 4,  5,  6]],

        [[44, 55, 66]]]) torch.Size([2, 1, 3])
tensor([[[ 7,  8,  9]],

        [[77, 88, 99]]]) torch.Size([2, 1, 3])

tensor([[[[ 1,  2,  3]],

         [[ 4,  5,  6]],

         [[ 7,  8,  9]]],


        [[[12, 22, 33]],

         [[44, 55, 66]],

         [[77, 88, 99]]]]) torch.Size([2, 3, 1, 3])
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[12, 22, 33],
         [44, 55, 66],
         [77, 88, 99]]]) torch.Size([2, 3, 3])
q1, q2 = torch.split(result3, split_size_or_sections=[2, 1], dim=1)
print(q1, q1.shape)
print(q2, q2.shape)
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[12, 22, 33],
         [44, 55, 66]]]) torch.Size([2, 2, 3])
tensor([[[ 7,  8,  9]],

        [[77, 88, 99]]]) torch.Size([2, 1, 3])
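The two split calls above show the two forms of split_size_or_sections: an int gives the size of every chunk, and a list gives the size of each chunk explicitly. The sketch below illustrates one case the examples do not cover, an int that does not divide the dimension evenly; the 1-D tensor x is a hypothetical input chosen for brevity.

```python
import torch

x = torch.arange(10)  # hypothetical input with values 0 through 9

# An int asks for chunks of that size; the last chunk is smaller if needed.
chunks = torch.split(x, 3)
sizes = [c.shape[0] for c in chunks]
print(sizes)  # [3, 3, 3, 1]

# A list gives each chunk's size and must sum to the length of the dim.
head, tail = torch.split(x, [7, 3])
print(head.shape, tail.shape)  # torch.Size([7]) torch.Size([3])

# Concatenating the pieces along the split dim restores the original.
restored = torch.cat(chunks, dim=0)
print(torch.equal(restored, x))  # True
```

As in the q, k, v example, cat along the split dimension undoes the split, whereas stack would add an extra axis.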