筛选排除还没找到答案:
取数运算
正好遇到一个需求。
我有m行k列的一个表a,和一个长为m的索引列表b。
b中储存着,取每行第几列的元素。
这种情况下,你用普通的索引是会失效的。
import torch
a= torch.LongTensor([[1,2,3],[4,5,6]])
b= torch.LongTensor([0,1])
错误写法:
c= a[b]
print(c)
结果是第1行和第2行
方法1:
import torch
conf_data = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
b = torch.LongTensor([0, 1])
index_num = torch.arange(0, conf_data.size(0))
print(conf_data[index_num,b])
经过一番查找,发现我们可以用神奇的torch.gather()函数
import torch
a= torch.LongTensor([[1,2,3],[4,5,6]])
b= torch.LongTensor([0,1]).view(2,1)
c= torch.gather(input=a,dim=1,index=b)
print(c)
#tensor([[1],
[5]])