【Pytorch--代码技巧】各种论文代码常见技巧

博主在阅读论文源码时,常看到一些novel代码技巧,特此汇总

1 torch.view()

作用:重置Tensor对象维度

注意点:参数中的-1表示系统自动判断,因此每个view里面只能出现一个-1

x = torch.randn(4,4)
# 重置为向量
x.view(16).size()
# 重置为多维矩阵
x.view(2,2,4).size()
# -1 数字用法
x.view(-1,8).sieze
torch.Size([16])
torch.Size([2,2,4])
torch.Size([2,8])

2  torch.unsqueeze()

作用:升维,最常见的就是unsqueeze(-1)表示将一维升到二维

x = torch.randn(4)
x = x.unsqueeze(-1)
torch.Size([16,1])

3 torch.expand()

作用:升维

# x 维度为[4]
x =torch.tensor([1,2,3,4])

# x1 维度为[3,1,4]
x1 = x.expend(3,1,4)
print(x1)

>>> 
tensor([[[1, 2, 3, 4]],
        [[1, 2, 3, 4]],
        [[1, 2, 3, 4]]])
torch.Size([16,1])

4 torch.transpose(0, 1)

作用:转置

注意:
1 只能拓展维度,比如 A的shape为 2x4的,不能 A.expend(1,4),只能保证原结构不变,在前面增维,比如A.shape(1,1,4)
2 可以增加多维,比如x的shape为(4),x.expend(2,2,1,4)只需保证本身是4
3 不能拓展低维,比如x的shape为(4),不能x.expend(4,2)

x = torch.randn(16,1)
x = x.transpose(0,1)
torch.Size([1,16])

5 去除对角线元素

一般使用z方法,因为x方法对float不可用而z可以 

x = torch.randint(1, 4, (4, 4))
y = x ^ torch.diag_embed(torch.diag(x))
z = x - torch.diag_embed(torch.diag(x))
tensor([[2, 2, 2, 1],
        [3, 1, 1, 2],
        [3, 1, 3, 1],
        [3, 1, 2, 2]])
tensor([[0, 2, 2, 1],
        [3, 0, 1, 2],
        [3, 1, 0, 1],
        [3, 1, 2, 0]])
tensor([[0, 2, 2, 1],
        [3, 0, 1, 2],
        [3, 1, 0, 1],
        [3, 1, 2, 0]])

6 torch.gather()

作用:根据维度dim按照索引列表index从input中选取指定元素

b = torch.Tensor([[1,2,3],[4,5,6]])
print(b)
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print (torch.gather(b, dim=1, index=index_1))
print (torch.gather(b, dim=0, index=index_2))
tensor([[1., 2., 3.],
        [4., 5., 6.]])
tensor([[1., 2.],
        [6., 4.]])
tensor([[1., 5., 6.],
        [1., 2., 3.]])

7 nn.Parameter()

作用:将一个不可训练的tensor转换为一个有梯度的可训练的tensor。往往用在需要自己定义的bisa中

self.bias = nn.Parameter(torch.ones([5]))
output = self.Weight(x) + self.bias

8 nn.Sequential

作用:将多个网络模块组合成一个模块,需要注意的是相邻的两个网络模块之间的输入输出尺寸

        self.block = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(512 * 2, 128),
            nn.Linear(128, 16),
            nn.Linear(16, num_classes),
            nn.Softmax(dim=1)
        )

9 nn.moduleList

作用:将多个网络模块放到一个类List中,后续可以从中进行调用

# 定义三个输入channels为1,输出channels为2,卷积核为[2,768]、[3,768]、[4,768]的卷积
        self.convs = nn.ModuleList(
            [nn.Conv2d(in_channels=1, out_channels=self.num_filters,
                       kernel_size=(k, 768), ) for k in self.filter_sizes])

10 nn.MaxPool2d

作用:MaxPooling,提取重要信息,去掉不重要信息,从而减少计算开销

 其一共有6个基本参数

  • kernel_size:池化窗口大小,输入单值如3则为3×3,输入元组如(3,2)则为3×2
  • stride:步长,单值元组均可。默认与池化窗口大小一致
  • padding:填充,单值元组均可。默认为0
  • dilation:控制窗口中元素步幅,不重要!
  • return_indices:布尔类型,返回最大值位置索引
  • ceil_mode:布尔类型,默认为False。False为向下取整,True为向上取整

11 permute()

作用:根据位置进行维度转化

>>> x = torch.randn(2, 3, 5) 
>>> x.size() 
torch.Size([2, 3, 5]) 
>>> x.permute(2, 0, 1).size() 
torch.Size([5, 2, 3])

12 torch.cat(inputs,dim)

作用:根据维度进行Tensor拼接

b = torch.cat(a, 4)

13 torch.clamp(input, min, max, out=None)

作用:将维度限制在min和max之间

sum_mask = torch.clamp(sum_mask, min=1e-9)

猜你喜欢

转载自blog.csdn.net/ccaoshangfei/article/details/127025349