# 版权声明:转载注明出处 (Copyright notice: credit the source when reposting) https://blog.csdn.net/york1996/article/details/81875508
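#squeeze 函数:从张量的形状中删除单维度条目,即把shape中为1的维度去掉
# (squeeze: removes size-1 dimensions from a tensor's shape.)
#unsqueeze() 是squeeze()的反向操作,增加一个维度,该维度维数为1,可以指定添加的维度。例如unsqueeze(a,1)表示在1这个维度进行添加
# (unsqueeze() is the inverse of squeeze(): it inserts a new dimension of size 1 at the given position; e.g. unsqueeze(a, 1) inserts it at dim 1.)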
import torch
# Demo of torch.squeeze / torch.unsqueeze on a tensor of shape (2, 3, 1):
# only the last dimension has size 1.
a = torch.rand(2, 3, 1)
print(torch.unsqueeze(a, 2).size())  # torch.Size([2, 3, 1, 1])
print(a.size())  # torch.Size([2, 3, 1])

# squeeze() with no argument drops every size-1 dimension.
print(a.squeeze().size())  # torch.Size([2, 3])
# squeeze(dim) removes `dim` only if its size is 1; otherwise the shape is
# returned unchanged (dim 0 has size 2 here).
print(a.squeeze(0).size())  # torch.Size([2, 3, 1])
print(a.squeeze(-1).size())  # torch.Size([2, 3])
# The squeeze calls above did not modify `a` in place.
print(a.size())  # torch.Size([2, 3, 1])
print(a.squeeze(-2).size())  # torch.Size([2, 3, 1])
print(a.squeeze(-3).size())  # torch.Size([2, 3, 1])
print(a.squeeze(1).size())  # torch.Size([2, 3, 1])
print(a.squeeze(2).size())  # torch.Size([2, 3])

# For a 3-D tensor the valid squeeze dims are [-3, 2], so dim=3 is an error.
# Depending on the PyTorch version this is a RuntimeError or an IndexError;
# catch it so the rest of the demo still runs.
try:
    print(a.squeeze(3).size())
except (IndexError, RuntimeError) as err:
    # Dimension out of range (expected to be in range of [-3, 2], but got 3)
    print(type(err).__name__, err)

# unsqueeze() requires a `dim` argument; calling it without one is a TypeError.
try:
    print(a.unsqueeze().size())
except TypeError as err:
    # unsqueeze() missing 1 required positional arguments: "dim"
    print(type(err).__name__, err)

# unsqueeze accepts dims in [-4, 3] for a 3-D input (one past the current
# rank); a negative dim is interpreted as dim + a.dim() + 1.
print(a.unsqueeze(-3).size())  # torch.Size([2, 1, 3, 1])
print(a.unsqueeze(-2).size())  # torch.Size([2, 3, 1, 1])
print(a.unsqueeze(-1).size())  # torch.Size([2, 3, 1, 1])
print(a.unsqueeze(0).size())  # torch.Size([1, 2, 3, 1])
print(a.unsqueeze(1).size())  # torch.Size([2, 1, 3, 1])
print(a.unsqueeze(2).size())  # torch.Size([2, 3, 1, 1])
print(a.unsqueeze(3).size())  # torch.Size([2, 3, 1, 1])
# Prints the tensor values themselves (random), not just the shape.
print(torch.unsqueeze(a, 3))

# squeeze() with no argument removes all size-1 dims, wherever they are.
b = torch.rand(2, 1, 3, 1)
print(b.squeeze().size())  # torch.Size([2, 3])