IDDPM代码ResBlock和TimestepEmbedSequential解读
ResBlock中的forward和_forward的区别
# ResBlock是为了把embedding以残差的形式和图片加起来,即把时间信息融合到图片中去
class ResBlock(TimestepBlock):
# resblock是继承自timestepblock的,所以所有的resblock部分肯定是要传入embedding的
# 而在attention, 上采样,下采样都不需要传入embedding
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential(
normalization(channels),
SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.emb_layers = nn.Sequential(
SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
)
if self.out_channels == channels: # 如果通道数目一致的话,直接连起来就好
self.skip_connection = nn.Identity()
elif use_conv:# 如果通道数目不一致的话,可以用一个大小不变的卷积去做
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
else:# 在有的论文中,如果通道数目不一致的话,也可以用一个1*1的卷积去做逐点的卷积
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
return checkpoint(
self._forward, (x, emb), self.parameters(), self.use_checkpoint
)
def _forward(self, x, emb): # _forward 是私有方法,它执行实际的计算并将其结果返回给 forward
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift # ys = (1+scale), yb = bias
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h # identity
在这个类中,forward 是公开接口,用于将输入传递给该模块的子模块并返回结果。 _forward 是私有方法,它执行实际的计算并将其结果返回给 forward。在这个类中,我们可以看到 forward 方法调用了 checkpoint,以利用 PyTorch 的自动微分机制来减少内存的使用。然后,_forward 方法执行了所有计算,并将最终结果返回给 forward。因此,我们可以说 forward 方法是 ResBlock 类的外部接口,而 _forward 方法是其内部实现。
TimestepEmbedSequential
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
# emb:timestep embedding和condition embedding混合起来的
def forward(self, x, emb):
for layer in self:
if isinstance(layer, TimestepBlock):# 只有layer是timestepblock的时候才输入emb
x = layer(x, emb)
else:
x = layer(x)
return x
其调用语句为
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
) conv_nd的结果是怎么传入TimestepEmbedSequential的?
在调用TimestepEmbedSequential时,我们传递给它一个nn.Module对象,该模块是由conv_nd(dims, in_channels, model_channels, 3, padding=1)创建的。conv_nd是一个工厂函数,它根据给定的参数创建一个卷积层对象。在这里,conv_nd返回一个维度为dims(dims是一个整数,表示卷积层的维度,例如1表示一维卷积,2表示二维卷积,3表示三维卷积),输入通道数为in_channels,输出通道数为model_channels,卷积核大小为3的卷积层对象,并将其作为参数传递给TimestepEmbedSequential的构造函数。
在TimestepEmbedSequential的forward方法中,该卷积层对象将作为nn.Sequential的一个子模块来使用,即被添加到self列表中。在调用forward方法时,输入张量x和时间步骤嵌入张量emb将依次被传递给nn.Sequential中的每个子模块。当遇到一个子模块是TimestepBlock类型时,emb将被传递给该子模块的forward方法,作为额外的输入。在这个例子中,conv_nd不是TimestepBlock类型,所以emb将被忽略,仅传递x给该层的forward方法。