pytorch tensor 筛选排除

筛选排除还没找到答案:

取数运算
 

正好遇到一个需求。

我有m行k列的一个表a,和一个长为m的索引列表b。

b中储存着,取每行第几列的元素。

这种情况下,你用普通的索引是会失效的。

import torch
a= torch.LongTensor([[1,2,3],[4,5,6]])
b= torch.LongTensor([0,1])
 

错误写法:
c= a[b]
print(c)
 结果是第1行和第2行
 

方法1:

    import torch

    conf_data = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
    b = torch.LongTensor([0, 1])

    index_num = torch.arange(0, conf_data.size(0))

    print(conf_data[index_num,b])


 

经过一番查找,发现我们可以用神奇的torch.gather()函数

import torch
 
a= torch.LongTensor([[1,2,3],[4,5,6]])
b= torch.LongTensor([0,1]).view(2,1)
 
c= torch.gather(input=a,dim=1,index=b)
print(c)
#tensor([[1],
        [5]])
 
 

发布了2732 篇原创文章 · 获赞 1011 · 访问量 538万+

猜你喜欢

转载自blog.csdn.net/jacke121/article/details/104764740