torch.max()
pytorch文档中提到:该函数返回一个元组:(值,索引),其中值是给定维度dim中输入张量每行的最大值。索引是找到的每个最大值(argmax)的索引位置。
如果keepdim为True,则输出张量的大小与输入相同,但维度dim中的大小为1。否则dim被压缩,导致输出张量的维数比输入少1。
注:若有多个最大值,则返回第一个最大值的索引
代码演示
a = torch.randn(4, 4)
print(a)
#tensor([[-0.7670, -0.2193, 0.1777, 0.3602],
# [ 1.0125, 0.8830, -1.1294, -1.8622],
# [ 1.3611, 1.2073, 1.8415, -1.4175],
# [-0.7687, 0.6015, 0.1030, -0.1119]])
a1 = torch.max(a) # 所有元素中最大的
print(a1)
#tensor(1.8415)
a2 = torch.max(a, 0) # 返回每一列的最大值,及其索引
print(a2)
#torch.return_types.max(
#values=tensor([1.3611, 1.2073, 1.8415, 0.3602]),
#indices=tensor([2, 2, 2, 0]))
a3 = torch.max(a, 1) # 返回每一行的最大值,及其索引
print(a3)
#torch.return_types.max(
#values=tensor([0.3602, 1.0125, 1.8415, 0.6015]),
#indices=tensor([3, 0, 2, 1]))
a4 = torch.max(a, 1)[0] # 只返回最大值
print(a4)
#tensor([0.3602, 1.0125, 1.8415, 0.6015])
a5 = torch.max(a, 1)[1] # 只返回最大值索引
print(a5)
#tensor([3, 0, 2, 1])
a6 = torch.max(a, 1)[1].numpy() # 将结果转化为Numpy格式
print(a6)
#[3 0 2 1]