Tensor的缩减操作
Tensor的主要运算操作通常分为四大类:
- Reshaping operations(重塑操作)
- Element-wise operations(元素操作)
- Reduction operations(缩减操作)
- Access operations(访问操作)
缩减操作
一个张量的缩减操作是一个减少张量中包含的元素数量的操作,其实质就是允许我们对单个张量中的元素执行操作。常见的缩减操作主要有:
- sum
- prod
- mean
- std
- max
- argmax
下面以sum、max和argmax操作进行说明,演示的张量如下,其他操作可以类比进行理解。
t = torch.tensor([
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]
], dtype=torch.float32)
sum操作
sum操作是将张量中的每一个元素进行累加的操作,现在我们通过累加前后的元素个数对比来体会缩减操作的含义。
(1)检测操作之前的元素个数
t.numel()
显示结果:
12
(2)执行求和操作
t.sum()
显示结果:
tensor(24.)
(3)比较执行求和操作前后的元素个数
t.sum().numel() < t.numel()
显示结果:
True
由于该操作减少了许多元素,所以求和操作是一个缩减操作。
那么,现在需要思考的是,缩减操作总是会变成一个带有单个元素的张量吗?
其实不是的,我们可以传递维度参数的值来减少特定的轴。
(4)沿着第一个轴来进行求和
t.sum(dim=0)
显示结果:
tensor([6., 6., 6., 6.])
其相加过程就是取第一个轴的所有元素的总和,代码演示如下:
print(t[0])
print(t[1])
print(t[2])
print(t[0]+t[1]+t[2])
显示结果:
tensor([1., 1., 1., 1.])
tensor([2., 2., 2., 2.])
tensor([3., 3., 3., 3.])
tensor([6., 6., 6., 6.])
(5)沿着第二个轴来进行求和
t.sum(dim=1)
显示结果:
tensor([ 4., 8., 12.])
其相加过程就是取第二个轴的所有元素的总和,代码演示如下:
print(t[0].sum())
print(t[1].sum())
print(t[2].sum())
显示结果:
tensor(4.)
tensor(8.)
tensor(12.)
argmax操作
简化操作,其作用是返回张量中元素最大值的索引位置,当我们在一个张量上调用argmax方法时,这个张量被缩减为一个新的张量,它包含一个单独的索引值,指示着张量里面的最大值。
现假设存在以下张量:
t = torch.tensor([
[1, 0, 0, 2],
[0, 3, 3, 0],
[4, 0, 0, 5]
], dtype=torch.float32)
(1)处理整个张量
print(t.max())
print(t.argmax())
显示结果:
tensor(5.)
tensor(11)
可以看到 t.argmax() 返回的结果是11,其实这个11是先将张量进行flatten操作后再取索引的结果:
t.flatten()
显示结果:
tensor([1., 0., 0., 2., 0., 3., 3., 0., 4., 0., 0., 5.])
(2)处理特定的轴
指定第一个轴
print(t.max(dim=0))
print(t.argmax(dim=0))
显示结果:
torch.return_types.max(
values=tensor([4., 3., 3., 5.]),
indices=tensor([2, 1, 1, 2]))
tensor([2, 1, 1, 2])
指定第二个轴
print(t.max(dim=1))
print(t.argmax(dim=1))
显示结果:
torch.return_types.max(
values=tensor([2., 3., 5.]),
indices=tensor([3, 1, 3]))
tensor([3, 1, 3])