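注:
torch.permute() 和 torch.transpose() 的区别: torch.transpose(input, dim0, dim1) 一次只交换两个维度; torch.permute(input, dims) / Tensor.permute(*dims) 按给定的排列一次重排所有维度, 作用与 numpy 的 ndarray.transpose(*axes) 相同。下面的示例对比的正是 numpy 的 transpose 与 torch 的 permute。
正文: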
# Demo: numpy's ndarray.transpose(*axes) and torch's Tensor.permute(*dims) both
# reorder ALL axes according to a full permutation. Each intermediate result is
# printed so the two libraries can be compared side by side.
import numpy as np
import torch as t

# 1-D array 1..12, then reshaped to (2, 2, 3).
arr = np.arange(1, 13)
print(arr)
arr = arr.reshape(2, 2, 3)
print(arr)

# Swap axes 0 and 1 -> shape (2, 2, 3); result[j, i, k] == previous[i, j, k].
arr = arr.transpose(1, 0, 2)
print(arr)

# The same permutation in torch undoes the numpy transpose, because swapping
# two axes is its own inverse.
tensor = t.tensor(arr, dtype=t.float32)
tensor = tensor.permute(1, 0, 2)
print(tensor)

# Back to numpy (float32), then move the last axis to the front:
# (2, 2, 3) -> (3, 2, 2); result[k, i, j] == previous[i, j, k].
arr = tensor.numpy()
arr = arr.transpose(2, 0, 1)
print(arr)

# torch.tensor copies the (non-contiguous) transposed array;
# permute(2, 0, 1) maps shape (3, 2, 2) -> (2, 3, 2).
tensor = t.tensor(arr).permute(2, 0, 1)
print(tensor)

# permute(2, 1, 0) reverses the axis order; for shape (2, 3, 2) the shape is
# unchanged, but the elements along dims 0 and 2 are exchanged.
tensor = tensor.permute(2, 1, 0)
print(tensor)
[ 1  2  3  4  5  6  7  8  9 10 11 12]
[[[ 1  2  3]
  [ 4  5  6]]

 [[ 7  8  9]
  [10 11 12]]]
[[[ 1  2  3]
  [ 7  8  9]]

 [[ 4  5  6]
  [10 11 12]]]
tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.]],

        [[ 7.,  8.,  9.],
         [10., 11., 12.]]])
[[[ 1.  4.]
  [ 7. 10.]]

 [[ 2.  5.]
  [ 8. 11.]]

 [[ 3.  6.]
  [ 9. 12.]]]
tensor([[[ 1.,  7.],
         [ 2.,  8.],
         [ 3.,  9.]],

        [[ 4., 10.],
         [ 5., 11.],
         [ 6., 12.]]])
tensor([[[ 1.,  4.],
         [ 2.,  5.],
         [ 3.,  6.]],

        [[ 7., 10.],
         [ 8., 11.],
         [ 9., 12.]]])
Process finished with exit code 0