ModuleList 与 ModuleDict
1、ModuleList
1)ModuleList
接收一个子模块的列表作为输入,然后也可以类似 List 那样进行 append 和 extend 操作:
net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net.append(nn.Linear(256, 10)) # # 类似List的append操作
print(net[-1]) # 可使用类似List的索引访问
print(net)
# net(torch.zeros(1, 784)) # 会报NotImplementedError
# 输出:
# Linear(in_features=256, out_features=10, bias=True)
# ModuleList(
# (0): Linear(in_features=784, out_features=256, bias=True)
# (1): ReLU()
# (2): Linear(in_features=256, out_features=10, bias=True)
# )
\quad
2)nn.Sequential
和 nn.ModuleList
二者的区别:
nn.ModuleList
仅仅是一个储存各种模块的列表,这些模块之间没有联系(所以不用保证相邻层的输入输出维度匹配), 而nn.Sequential
内的模块需要按照顺序排列,要保证相邻层的输入输出大小相匹配nn.ModuleList
没有实现 forward 功能需要自己实现,所以上面执行net(torch.zeros(1, 784))
会报 NotImplementedError;而nn.Sequential
内部 forward 功能已经实现。
ModuleList 的出现只是让网络定义前向传播时更加灵活,见下面官网的例子:
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
for i, l in enumerate(self.linears):
x = self.linears[i // 2](x) + l(x)
return x
\quad
3)另外,nn.ModuleList
不同于一般的 Python 的 list,加入到 nn.ModuleList
里面的所有模块的参数会被自动添加到整个网络中,下面看一个例子对比一下。
import torch
import torch.nn as nn
class Module_ModuleList(nn.Module):
def __init__(self):
super(Module_ModuleList, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10)])
class Module_List(nn.Module):
def __init__(self):
super(Module_List, self).__init__()
self.linears = [nn.Linear(10, 10)]
net1 = Module_ModuleList()
net2 = Module_List()
print(net1)
for p in net1.parameters():
print(p.size())
print('*'*20)
print(net2)
for p in net2.parameters():
print(p)
输出
Module_ModuleList(
(linears): ModuleList(
(0): Linear(in_features=10, out_features=10, bias=True)
)
)
torch.Size([10, 10])
torch.Size([10])
********************
Module_List()
2、ModuleDict
ModuleDict接收一个子模块的字典作为输入, 然后也可以类似字典那样进行添加访问操作:
net = nn.ModuleDict({
'linear': nn.Linear(784, 256),
'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10) # 添加
print(net['linear']) # 访问
print(net.output)
print(net)
# net(torch.zeros(1, 784)) # 会报NotImplementedError
# 输出:
# Linear(in_features=784, out_features=256, bias=True)
# Linear(in_features=256, out_features=10, bias=True)
# ModuleDict(
# (act): ReLU()
# (linear): Linear(in_features=784, out_features=256, bias=True)
# (output): Linear(in_features=256, out_features=10, bias=True)
# )
(1)和 nn.ModuleList
一样,nn.ModuleDict
实例仅仅是存放了一些模块的字典,并没有定义 forward函数 需要自己定义。
(2)同样,nn.ModuleDict
也与 Python 的 Dict 有所不同,nn.ModuleDict
里的所有模块的参数会被自动添加到整个网络中。
3、总结
Sequential
、ModuleList
、ModuleDict
类都继承自 Module 类。- 与
Sequential
不同,ModuleList
和ModuleDict
并没有定义一个完整的网络,它们只是将不同的模块存放在一起,需要自己定义 forward 函数。 - 虽然
Sequential
等类可以使模型构造更加简单,但直接继承 Module 类可以极大地拓展模型构造的灵活性。