Pytorch通过.unsqueeze(int)
方法来增加1个维度,传的int值为增加的维度的索引。下面通过程序来说明其用法。
-
生成测试数据
import torch t1 = torch.tensor([1,2,3])
-
进行维度增加
print(t1.unsqueeze(0)) # tensor([[1, 2, 3]]) print(t1.unsqueeze(1)) # tensor([[1], # [2], # [3]]) print(t1.unsqueeze(2)) # 报错 # IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
-
从上面可能不容易观察传入索引的作用,观察一下增加后的size()
print(t1.size()) print(t1.unsqueeze(0).size()) print(t1.unsqueeze(1).size()) # torch.Size([3]) # torch.Size([1, 3]) # torch.Size([3, 1])
当前维度是1,增加1维后维度变成2,索引值也就只能由两个(0和1),所以上面以2为索引时报错了
-
用二维数据进行测试
t2 = torch.tensor([[1,2,3],[4,5,6]]) print(t2.size()) # torch.Size([2, 3]) print(t2.unsqueeze(0).size()) # torch.Size([1, 2, 3]) print(t2.unsqueeze(1).size()) # torch.Size([2, 1, 3]) print(t2.unsqueeze(2).size()) # torch.Size([2, 3, 1]) print(t2.unsqueeze(3).size()) # IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
2+1=3,索引可以是0,1,2,当索引时3的时候报错
-
我想看看增加到三维是个什么样的
print(t2.unsqueeze(0)) # tensor([[[1, 2, 3], # [4, 5, 6]]]) print(t2.unsqueeze(1)) # tensor([[[1, 2, 3]], # # [[4, 5, 6]]])