版权声明:转载注明出处 https://blog.csdn.net/york1996/article/details/84234352
话不多说,上代码,上面写的很清楚。
import torch.nn as nn
import torch
net= nn.Sequential(
nn.Linear(1024, 512),
nn.ReLU(inplace=True),
nn.Linear(512, 256),
nn.ReLU(inplace=True),
nn.Linear(256, 6),
)
net[4].weight.data=torch.zeros(6,256)
net[4].bias.data=torch.ones(6)
t=torch.randn(32,1024)
print(net(t).size())
cnn=nn.Sequential(
nn.Conv2d(2,8,3,1,1),
nn.Conv2d(8,19,3,1,1)
)
cnn[0].weight.data=torch.randn(8*2*3*3).view(8,2,3,3)
cnn[0].bias.data=torch.ones(8)
t=torch.randn(32,2,100,100)
print(cnn(t).size())
注:
- 注意线性层和卷积层输入通道和输出通道的关系,初始化的时候要是转置的形式。
- 后面生成一个tensor送入网络是为了测试初始化的正确性
- 出错的话请参考:The expanded size of the tensor (256) must match the existing size (81) at non-singleton dimension1
- 虽然代码中写的是sequential,对于module同样是可以用的