nn.Sequential 和 nn.ModuleList()的联系与区别

nn.Sequentialnn.ModuleList()PyTorch 中用于管理神经网络模型中的子模块的两种不同的方式。

nn.Sequential 是一个用于构建顺序模型的容器类。它允许按照给定的顺序添加一系列的子模块,并将它们串联在一起形成一个顺序的网络结构。nn.Sequential 可以简化模型的定义和前向传播的编写,特别适用于那些没有复杂控制流程的简单网络结构。通过向 nn.Sequential 中添加子模块,这些子模块会自动按照添加的顺序连接在一起,并形成一个整体的模型。在调用 nn.Sequentialforward 方法时,输入数据将按照添加的顺序经过每个子模块,从而实现整个模型的前向传播。

示例使用 nn.Sequential 构建一个简单的模型:

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 10)
)

input_tensor = torch.randn(32, 10)
output_tensor = model(input_tensor)

在这个示例中,我们通过 nn.Sequential 定义了一个顺序模型。顺序模型包含三个子模块:一个线性层、一个 ReLU 激活函数和另一个线性层。当我们调用模型的 forward 方法时,输入数据 input_tensor 将按照添加的顺序依次经过每个子模块,并生成输出数据 output_tensor


相比之下,nn.ModuleList() 是一个类似于 Python 列表的容器,用于存储和管理任意数量的子模块。与 nn.Sequential 不同的是,nn.ModuleList() 并不自动连接子模块,而是将其存储为列表的形式。因此,在使用 nn.ModuleList() 定义模型时,我们需要自己定义子模块之间的连接关系。这使得 nn.ModuleList() 更加灵活,适用于那些具有复杂控制流程或需要自定义连接方式的网络结构。

示例使用 nn.ModuleList() 构建一个简单的模型:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

        self.module_list = nn.ModuleList([
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 10)
        ])

    def forward(self, x):
        for module in self.module_list:
            x = module(x)
        return x

model = MyModel()
input_tensor = torch.randn(32, 10)
output_tensor = model(input_tensor)

在这个示例中,我们定义了一个自定义的模型类 MyModel,其中使用了 nn.ModuleList() 来存储三个子模块:一个线性层、一个ReLU 激活函数和另一个线性层。在模型的 forward 方法中,我们通过迭代 module_list 中的子模块,依次将输入数据 x 传递给它们,并获取最终的输出。

因此,nn.Sequentialnn.ModuleList() 的区别在于自动连接子模块的能力。nn.Sequential 自动按照添加的顺序连接子模块,适用于简单的顺序模型。而 nn.ModuleList() 则需要手动定义子模块之间的连接方式,适用于具有复杂控制流程或自定义连接的模型。

此外,nn.Sequential 还提供了更简洁的语法来定义模型,因为它可以直接通过传入子模块的列表来创建模型。而 nn.ModuleList() 则需要显式地在模型类中定义和初始化子模块。

nn.Sequentialnn.ModuleList() 都是 nn.Module 的子类,因此它们都可以作为模型的属性进行注册和管理。

猜你喜欢

转载自blog.csdn.net/AdamCY888/article/details/131270539