torch.index_select(self: Tensor, dim: Union[str, None], index: Tensor)
:第一个Tensor是被操作的tensor,dim表示要操作的维度,index是在这个维度下提取的索引张量
例子:2维张量
import torch
a = torch.linspace(1, 12, steps=12).view((4, 3))
print(a)
>>> tensor([[ 1., 2., 3.],
[ 4., 5., 6.],
[ 7., 8., 9.],
[10., 11., 12.]])
print(torch.index_select(a, 0, torch.tensor([0, 3]))) # 提取第0个维度中的0和3
>>> tensor([[ 1., 2., 3.],
[10., 11., 12.]])
例子:3维张量
import torch
a = torch.linspace(1, 12, steps=12).view((2, 2, 3))
print(a)
>>> tensor([[[ 1., 2., 3.],
[ 4., 5., 6.]],
[[ 7., 8., 9.],
[10., 11., 12.]]])
print(torch.index_select(a, 2, torch.tensor([0, 2]))) # 提取第2个dim的0和2
>>> tensor([[[ 1., 3.],
[ 4., 6.]],
[[ 7., 9.],
[10., 12.]]])