1.介绍
最先进的延迟-精度权衡
在两个广泛使用的平台——移动设备和桌面GPU上的延迟是最快的
1.1 之前存在的问题
由于内存访问成本的增加,跳过连接(skip connections)在延迟方面造成了很大的开销
1.1.1 跳过连接 为什么 会增加内存访问成本?
这主要与以下两个因素有关:
-
内存需求:跳过连接会引入额外的特征映射或张量,从而增加了模型内存的使用。每个跳过连接需要存储一定数量的特征映射,这可能导致更多的内存占用。虽然这种增加通常不会显著影响模型的总体内存占用,但在某些情况下可能会有一些额外的内存压力。
-
内存访问:在深度神经网络中,计算和存储特征映射需要大量的内存访问。跳过连接增加了需要读取和写入的特征映射数量,这可能会导致内存访问成本的增加。
1.2 FastViT 改进之处
- 引入了完全可重新参数化的RepMixer来删除跳过连接。
- 将所有密集的k×k卷积替换为它们的分解版本,即逐通道卷积和逐点卷积。使用了线性训练时间过参数化( linear train-time overparameterization)。这些额外的分支仅在训练期间引入,并在推理时重新参数化。
- 在早期阶段使用大卷积核来替代自关注层
1.3 前置知识
- 解耦 训练 和 推理
- 参数重参数化
- 使用大核卷积的好处(为什么可以使用大核卷积替代自关注层)
2. FastViT
(a)解耦训练时间和推理时间架构的FastViT架构概述。阶段1、2和3具有相同的体系结构,并使用RepMixer进行令牌混合。在阶段4中,自关注层用于令牌混合。(b)卷积系统的结构。(c)卷积- ffn的体系结构(d) RepMixer块概述,它在推理时重新参数化跳过连接。
基本内容已经在上图片上了
【论文解读参考】解读模型压缩24:FastViT:快速卷积 Transformer 的混合视觉架构 - 知乎 (zhihu.com)
3.代码
3.1 不同型号
FastViT 一共有6个型号。
3.2 RepCPE
在fastvit_sa12,fastvit_sa24,fastvit_sa36,fastvit_ma36等模型中
pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
在推理时将原先的残差分支去除,只看前向推理函数
#reparam_conv 和 self.pe其实是参数基本相同的卷积
def forward(self, x: torch.Tensor) -> torch.Tensor:
if hasattr(self, "reparam_conv"):
x = self.reparam_conv(x)
return x
else:
x = self.pe(x) + x
return x
3.3 basic_blocks
token_mixer有两种选择,RepMixerBlock和AttentionBlock,一般也只会在最后一个层结构中使用AttentionBlock
3.3.1 RepMixerBlock
和其他的Mixer结构相似
3.3.1.1 RepMixer
(1)训练结构
残差结构的实现是通过带分支结构的MobileOneBlock减去不带分支结构的MobileOneBlock实现的
x = x + self.mixer(x) - self.norm(x)
(2)重参数
重参数的原理和RepVGG相同,不过多赘述了
3.3.1.2 ConvFFN
其实也相对简单,值得一提的是ConvFFN似乎不进行重参数化
3.4 PatchEmbed
区别一般的,通过一个ReparamLargeKernelConv和一个MobileOneBlock进行映射
ReparamLargeKernelConv来自Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNshttps://openaccess.thecvf.com/content/CVPR2022/papers/Ding_Scaling_Up_Your_Kernels_to_31x31_Revisiting_Large_Kernel_Design_CVPR_2022_paper.pdf
class PatchEmbed(nn.Module):
"""Convolutional patch embedding layer."""
def __init__(
self,
patch_size: int,
stride: int,
in_channels: int,
embed_dim: int,
inference_mode: bool = False,
) -> None:
super().__init__()
block = list()
block.append(
ReparamLargeKernelConv(
in_channels=in_channels,
out_channels=embed_dim,
kernel_size=patch_size,
stride=stride,
groups=in_channels,
small_kernel=3,
inference_mode=inference_mode,
)
)
block.append(
MobileOneBlock(
in_channels=embed_dim,
out_channels=embed_dim,
kernel_size=1,
stride=1,
padding=0,
groups=1,
inference_mode=inference_mode,
use_se=False,
num_conv_branches=1,
)
)
self.proj = nn.Sequential(*block)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
return x