函数torch.
gather
(input, dim, index, out=None, sparse_grad=False) → Tensor
沿给定轴 dim ,将输入索引张量 index 指定位置的值进行聚合.
- input (Tensor) – 原张量
- dim (int) – 索引的轴,二维中dim=0代表以每一列为独立个体,对列中元素进行索引排序,dim=1代表以每一行为独立个体,对行中元素进行索引排序。
- index (LongTensor) – 索引
- out (Tensor, optional) – 目标张量
- sparse_grad (bool,optional) – If
True
, gradient w.r.t.input
will be a sparse tensor.(没用过)
b = torch.Tensor([[1,2,3],[4,5,6]]) print(b) index_1 = torch.LongTensor([[0,1,0,1],[2,0,2,0]]) index_2 = torch.LongTensor([[0,1,1],[0,0,0],[1,1,1]]) print (torch.gather(b, dim=1, index=index_1))#以每一行为独立个体,对行中元素进行索引排序,所以索引表index的行数需要等于原矩阵的行数,对每列中的每个元素进行编号。 print (torch.gather(b, dim=0, index=index_2))#以每一列为独立个体,对列中元素进行索引排序,所以索引表index的列数需要等于原矩阵的列数,对每行中的每个元素进行编号。
上述输出为:
tensor([[1., 2., 3.], [4., 5., 6.]]) tensor([[1., 2., 1., 2.], [6., 4., 6., 4.]]) tensor([[1., 5., 6.], [1., 2., 3.], [4., 5., 6.]])
官方文档,三维举例:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
Example:
>>> t = torch.tensor([[1,2],[3,4]]) >>> torch.gather(t, 1, torch.tensor([[0,0],[1,0]])) tensor([[ 1, 1], [ 4, 3]])
三维情况说明:
三维中: dim=0代表以每一个小channel为独立个体(一共有行x列个),对每个channel中的元素进行索引排序。 index可以有很多个channel但是行数、列数需要等于原矩阵的函数列数,否则,会超出索引范围。
dim=1代表以每一列为独立个体,对列中元素进行索引排序,
dim=2代表以每一行为独立个体,对行中元素进行索引排序。
a = torch.randint(0, 30, (2, 3, 5)) print(a) index = torch.LongTensor([[[0,1,2,0,2], [0,0,0,0,0], [1,1,1,1,1]], [[1,2,2,2,2], [0,0,0,0,0], [2,2,2,2,2]]]) index2 = torch.LongTensor([[[0,1,1,0,1], [0,1,1,1,1], [1,1,1,1,1]], [[1,0,0,0,0], [0,0,0,0,0], [1,1,0,0,0]], [[1,0,0,0,0], [0,0,0,0,0], [1,1,0,0,0]]]) # dim=0 b = torch.gather(a,0,index2) print("dim=0:\n",b) #dim=1 c = torch.gather(a,1,index) print("dim=1:\n",c) #dim=2 d = torch.gather(a,2,index) print("dim=2:\n",d)
输出:
tensor([[[26, 5, 16, 8, 8], [22, 7, 9, 27, 12], [25, 10, 7, 6, 4]], [[ 4, 11, 2, 2, 2], [12, 0, 21, 13, 7], [ 2, 20, 13, 26, 2]]]) dim=0: tensor([[[26, 11, 2, 8, 2], [22, 0, 21, 13, 7], [ 2, 20, 13, 26, 2]], [[ 4, 5, 16, 8, 8], [22, 7, 9, 27, 12], [ 2, 20, 7, 6, 4]], [[ 4, 5, 16, 8, 8], [22, 7, 9, 27, 12], [ 2, 20, 7, 6, 4]]]) dim=1: tensor([[[26, 7, 7, 8, 4], [26, 5, 16, 8, 8], [22, 7, 9, 27, 12]], [[12, 20, 13, 26, 2], [ 4, 11, 2, 2, 2], [ 2, 20, 13, 26, 2]]]) dim=2: tensor([[[26, 5, 16, 26, 16], [22, 22, 22, 22, 22], [10, 10, 10, 10, 10]], [[11, 2, 2, 2, 2], [12, 12, 12, 12, 12], [13, 13, 13, 13, 13]]])
dim = 0的时候(三维情况下),举的例子只有2 channels。所以index在0,1两个之间选择。 输出的矩阵元素也是按照index的指定。分别在1st channel和2nd channel之间选。 index [0,1,1,0,1]的分别代表第一个元素在1st channel选,第二个元素在2nd channel选,第三个元素在2nd channel选,第四个元素在1st channel选,第五个元素在2nd channel选。