求取指定维度的均值
import torch
x=torch.arange(15).view(3,5)
x = x.float()
print(x)
x = x.mean(dim=1,keepdim=True)
print(x)
输出:
tensor([[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.],
[10., 11., 12., 13., 14.]])
tensor([[ 2.], [ 7.], [12.]])
求取指定维度的均值
import torch
x=torch.arange(15).view(3,5)
x = x.float()
print(x)
x = x.mean(dim=1,keepdim=True)
print(x)
输出:
tensor([[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.],
[10., 11., 12., 13., 14.]])
tensor([[ 2.], [ 7.], [12.]])