今天在学习TTSR的过程总遇到了一行代码,我发现max()函数竟然可以返回两个值,于是我决定重新学习一下这个函数
R_lv3_star, R_lv3_star_arg = torch.max(R_lv3, dim=1) #[N, H*W] hi
1、基础用法:
首先是 torch.max()的基础用法,输入一个张量,返回一个确定的最大值
torch.
max
(input) → Tensor
Example:
>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.6763, 0.7445, -2.2369]])
>>> torch.max(a)
tensor(0.7445)
2、深度用法:
torch.
max
(input, dim, keepdim=False, *, out=None)
按维度dim 返回最大值,并且返回索引。
Parameters
-
input (Tensor) – the input tensor.
-
dim (int) – the dimension to reduce.
-
keepdim (bool) – whether the output tensor has
dim
retained or not. Default:False
.
Keyword Arguments
out (tuple, optional) – the result tuple of two output tensors (max, max_indices),返回的最大值和索引各是一个tensor,分别表示该维度的最大值,以及该维度最大值的索引,一起构成元组(Tensor, LongTensor)
Example:
torch.max(a,0)返回每一列中最大值的那个元素,且返回索引(返回最大元素在这一列的行索引)。返回的最大值和索引各是一个tensor,一起构成元组(Tensor, LongTensor)
a = torch.randn(4, 4)
print(a)
print(torch.max(a,0))
tensor([[ 0.7439, 2.2739, -2.7576, -0.0676],
[-0.7755, -0.6696, 0.3009, -1.4939],
[-0.9244, 2.7325, 1.7982, 1.2904],
[-0.9091, -0.1857, -1.3392, -1.2928]])
torch.return_types.max(
values=tensor([0.7439, 2.7325, 1.7982, 1.2904]),
indices=tensor([0, 2, 2, 2]))
torch.max(a,1)返回每一行中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引)
>>> a = torch.randn(4, 4)
>>> a
tensor([[-1.2360, -0.2942, -0.1222, 0.8475],
[ 1.1949, -1.1127, -2.2379, -0.6702],
[ 1.5717, -0.9207, 0.1297, -1.8768],
[-0.6172, 1.0036, -0.6060, -0.2432]])
>>> torch.max(a, 1)
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))