torch.sort(input, dim=-1, descending=False, stable=False, *, out=None)
1.1 作用
根据给定的维度对输入张量进行升值或降值排序。
1.2 参数
input: 需要是一个torch.Tensor类型的张量。
dim: 给定一个张量的维度(int型),按照这个维度上的数值进行排序。如果不指定,默认按照张量的最后一个维度进行排序。
descending:传入一个布尔类型的数据(Ture、False),True代表降值排序,False代表升值排序。如果不指定,默认升值排序。
stable:传入一个布尔类型的数据(Ture、False),当一个张量中存在多个相同数字时,例如[2, 2, 1, 1],传入True不会打乱同一个数字的先后顺序(第一个1会排在第一个,第二个1会排在第二个)。如果不指定,默认False。
out:(Tensor, LongTensor) 的输出元组,可以选择用作输出缓冲区。如果不指定,默认None。
1.3 举例
先是只传入张量,其他参数均为默认:
import torch
tensor_a = torch.tensor([[2, 1],
[3, 4],
[6, 5]])
sorted_tensor_a, indices = torch.sort(tensor_a)
print(sorted_tensor_a, '\n', indices)
#---------输出---------#
tensor([[1, 2],
[3, 4],
[5, 6]])
tensor([[1, 0],
[0, 1],
[1, 0]])
#----------------------#
dim = 0 的情况:
import torch
tensor_a = torch.tensor([[6, 1],
[1, 4],
[2, 5]])
sorted_tensor_a, indices = torch.sort(tensor_a, dim=0)
print(sorted_tensor_a, '\n', indices)
#---------输出---------#
tensor([[1, 1],
[2, 4],
[6, 5]])
tensor([[1, 0],
[2, 1],
[0, 2]])
#----------------------#
descending = True 的情况:
import tensor
tensor_a = torch.tensor([[6, 1],
[1, 4],
[2, 5]])
sorted_tensor_a, indices = torch.sort(tensor_a, dim=0, descending=True)
print(sorted_tensor_a, '\n', indices)
#---------输出---------#
tensor([[6, 5],
[2, 4],
[1, 1]])
tensor([[0, 2],
[2, 1],
[1, 0]])
#----------------------#
stable = True 的情况:
import torch
tensor_a = torch.tensor([0, 1] * 9)
sorted_tensor_a, indices = torch.sort(tensor_a, stable=True)
print(sorted_tensor_a, '\n', indices)
#---------------------------输出---------------------------#
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])
#----------------------------------------------------------#
sorted_tensor_a, indices = torch.sort(tensor_a)
print(sorted_tensor_a, '\n', indices)
#---------------------------输出---------------------------#
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])
#----------------------------------------------------------#
可以看到,在我的运行结果中,stable无论是True还是False,好像结果都是一样的,但是以下是官方教程中的例子:
同样的函数,同样的输入,为什么我的和官方的输出不一样我也不是很清楚。。。