pytorch中的神经网络子模块(线性模块)——torch.nn.Linear

pytorch中的线性模块的实现如下,在init函数中定义weight值和bias值。

class Linear(Module):
    __constants__ = ['bias', 'in_features', 'out_features']

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

所以若要对linear子模块的参数进行初始化,利用如下策略可以对单个linear子模块进行参数初始化。

import torch.nn as nn
from torch.nn import init
from collections import OrderedDict
net = nn.Sequential(OrderedDict([
          ('linear', nn.Linear(num_inputs, 1))
        ]))

print(net )
print(net[0])

init.normal_(net[0].weight, mean=0.0, std=0.01)
init.constant_(net[0].bias, val=0.0)  # 也可以直接修改bias的data: net[0].bias.data.fill_(0)

#----------------
LinearNet(
  (linear): Linear(in_features=2, out_features=1, bias=True)
)
Linear(in_features=2, out_features=1, bias=True)
<class 'torch.nn.modules.linear.Linear'>
发布了233 篇原创文章 · 获赞 187 · 访问量 40万+

猜你喜欢

转载自blog.csdn.net/qiu931110/article/details/104292129