1. 张量的拼接、扩维拼接和扩维
1.1 torch.cat
在 PyTorch 中,可以使用 torch.cat
函数来进行数组拼接操作。torch.cat 可以将两个或多个张量(tensor)拼接在一起,可以指定拼接的维度。例如,假设我们有两个大小为 (3, 4) 和 (3, 5) 的张量,我们可以在第二个维度上拼接它们,代码如下:
import torch
a = torch.randn(3, 4)
b = torch.randn(3, 5)
c = torch.cat([a, b], dim=1)
print(c.shape)
输出结果为:
torch.Size([3, 9])
1.2 torch.stack
torch.stack
函数可以将多个张量沿着新创建的维度进行拼接,例如:
import torch
a = torch.randn(3, 4)
b = torch.randn(3, 4)
c = torch.stack([a, b], dim=0)
print(c.shape)
输出结果为:
torch.Size([2, 3, 4])
1.3 torch.unsqueeze
如果要将一个张量扩展一个新的维度,可以使用 torch.unsqueeze
函数,例如:
import torch
a = torch.randn(3, 4)
b = torch.unsqueeze(a, dim=0)
print(b.shape)
输出结果为:
torch.Size([1, 3, 4])
2 数组的拼接、扩维拼接和扩维
2.1 np.concatenate
在 NumPy 中,可以使用 np.concatenate
函数来进行数组拼接操作。np.concatenate
可以将两个或多个数组拼接在一起,可以指定拼接的维度。例如,假设我们有两个大小为 (3, 4)
和 (3, 5)
的数组,我们可以在第二个维度上拼接它们,代码如下:
import numpy as np
a = np.random.randn(3, 4)
b = np.random.randn(3, 5)
c = np.concatenate([a, b], axis=1)
print(c.shape)
输出结果为:
(3, 9)
2.2 np.stack
np.stack
函数可以将多个数组沿着新创建的维度进行拼接,例如:
import numpy as np
a = np.random.randn(3, 4)
b = np.random.randn(3, 4)
c = np.stack([a, b], axis=0)
print(c.shape)
输出结果为:
(2, 3, 4)
2.3 np.newaxis
如果要将一个数组扩展一个新的维度,可以使用 np.newaxis
函数,例如:
import numpy as np
a = np.random.randn(3, 4)
b = a[np.newaxis, ...]
print(b.shape)
输出结果为:
(1, 3, 4)