5.一脚踹进ViT——Swin Transformer(上)

Swin Transformer

解决ViT对于下游任务不友好的问题,提出了滑动窗口

Swin的特点:

  1. 从小Patch开始,逐层合并相邻Patch

  2. 计算Window Attention

  3. 提出Shifted Window操作,更有效计算Attention

1.论文阅读笔记

Swin Transformer中使用移动窗口构建了层级式的ViT,让ViT像CNN一样也能够分成几个block,能做层级式的特征提取,对图像大小具有线性计算复杂度。

1.1 摘要:

指出问题Transformer从NLP用到视觉中的问题:

  1. 尺度太大(街景中有行人和汽车,有各种各样的尺寸,但在NLP中不存在)

  2. resolution太大导致序列太长,计算量很大

之前解决办法:

  1. 用后续特征图作为Transformer的输入

  2. 将图片打成Patch,减少图片的resolution

  3. 将图片划成一个个的小窗口,在窗口中做自注意力。

本文提出移动窗口,不仅减少了计算量,因为有移动操作,能让相邻的两个窗口之间有了交互,于是上下层之间有了cross-window的连接,这种层级式结构的好处不仅灵活,提供各个尺度的信息,同时自注意力在小窗口内算的,计算复杂度随着图像大小而线性增长的。

1.2 结论:

在这里插入图片描述

使用小窗口内计算自注意力,而不同与ViT在整图上算自注意力,只要窗口大小是固定的,那SA的复杂度就是固定的,整张图的计算复杂度就随着图像大小呈线性增长关系,图像变为x倍,窗口数量就增大x倍,复杂度就是x倍,而非x的平方。

利用的CNN中局部性的归纳偏置,同一个物体的不同部位(语义相近的不同物体),还是大概率会出现在相连的地方,即使在小范围的窗口中也是够用的,全局中算注意力可能是浪费资源的。

CNN中多尺寸的特征是由池化操作,能够增大卷积核能看到的感受野,从而使得每次池化过后的特征抓住物体的不同尺寸,本文提出了patch merging,使得相邻的patch合成一个大patch,使得感受野增大,抓住多尺度特征。有了4×,8×,16×这种多尺度特征,扔给FPN就能做检测了,扔给UNET就能做分割了,所以Swin Transformer能当作骨干网络来做。

在这里插入图片描述

划分完后,窗口与窗口之间可以互动

在这里插入图片描述

假如一张图片224×224×3,首先打成4×4的patch,每个patch中图像大小变为原来1/4,即56,维度就变为4×4×3;接下来Linear Embedding把向量维度变为预先设置好的值(Swin Transformer能接受的值),超参数C=96,走完就变为56×56×96,拉直后就变为3136 ×96,而ViT中序列长度是16×16,而此时是3136非常大,本文是基于窗口来算的,窗口中仅有7×7=49个patch,暂且将它当作黑盒,在其中做了自注意力操作,如果不对其做约束,输入输出的尺寸是不变的,即输出还是56×56×96。

在Patch Merging中使用两个1×1的卷积来降维,将通道数由4C变为2C,目的是:与池化层相同,使得图像大小翻倍,通道数减半。

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

2.Swin Transformer架构分析

在这里插入图片描述

类似ViT架构,对于输入图像进行Patch Partition(图像分块)、Patch Embedding,然后经过4个stage,类似于ResNet中的Stage,在Stage中主要由 Swin Transformer Block 构成,最后进行一个Patch Merging进行融合。

最关键的操作就是Swin Transformer BlockPatch Merging

上图展示了模型的大致结构图,我们还可以关心一下数据的流动

在这里插入图片描述

将图中拆分来看

1.以patch输入到网络后,如果是彩色图像,它的channel就为3,经过Patch Embedding后,通道数就变为embed_dim

在这里插入图片描述

2.得到Patch Embedding后,再使用 窗口(Windows)再切一次patch,我们目前的输入已经是feature level的tensor,做一次Windows Partition切成不重叠的窗口

在这里插入图片描述

3.如果没有划分window,我们做的就是每一个patch与其他所有batch,现在划分后,现在在每一个窗口中单独去做即可,可以减少计算量,不需要算每个batch与其他batch,我们经过attention后,输出维度和输入一样,这样对每个单独的window做完后,最终维度同样尺寸大小的tensor

在这里插入图片描述

4.Patch Merging

在这里插入图片描述

swin transformer中就是对相邻的4个image token融合起来,空间上尺寸变小,同时会将embed_dim的维度扩大2倍

5.Next Stage

在这里插入图片描述

一个Stage部分做完后,走到了下一个Stage,此时输入是merge之后更小的输入,继续重复上述步骤,切window,减少尺寸,升高维度

在这里插入图片描述

在某些Stage中,重复做多次Block块,但是其中不会改变尺寸,输入输出维度一直不变

2.1 Swin Transformer Block

本部分介绍Block是如何构建的

在这里插入图片描述

在这里插入图片描述

W-MSA(Window Multi-head Self Attention)和SW-MSA(Shifted Window Multi-head Self Attention)构成,本文先不介绍移动窗口部分,仅看左侧如何处理,数据进入后通过LN层,然后再走W-MSA,进行残差连接,再走LN,MLP,残差融合,与之前的差不多,就是需要修改W-MSA部分

W-MSA

在这里插入图片描述

Tensor通过划分Window操作后,只拿自己的Window出来,把其中16个token提出来,来做Attention;之后再拿window2自己的16个token做attention,每一个分开做

在这里插入图片描述

论文说W-MSA比MSA计算量小,推算一下公式

在这里插入图片描述

可以看到两个公式第二项不一样,MSA中随着h·w尺寸/(patch_num)呈平方级增长,但在W-MSA中呈线性关系,如果图像切图像尺寸越小,使用W-MSA效率会更高。

2.2 Patch Merging

在这里插入图片描述

将同一个window中画的颜色不同四个部分,排到一起,将每一部分的tokens再并列,原来merge后得到小window的dim会变为原来的4倍,此时token的数量会变为原来的1/4,之后在做一次,将其映射到2倍。

最后对映射后的内容reshape回去,于是长宽均变为原来1/2,维度变为2倍

3. 代码实现

涉及W-MSA 和 Patch Merging 以及 Window Partition

在这里插入图片描述

Window Partition是将我们的tensor切成window,然后送到attention中去算,所以有三个QKV。假设我们一个batch有3个样本,每个样本尺寸是一样的,都要把红框的window切出来,每个window单独做attention。

可以把一个batch的所有小window拼到一起去,所有的window直接是没关系的,window只管自己的,不管怎么排列,我只算自己的

在这里插入图片描述

我们看到的一个方块的一个小格,要它与算窗口内其他所有格的attention,这叫做window_attention,然后将每个小窗口的16个token拉出来,展开,即

在这里插入图片描述

import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    def __init__(self,patch_size=4,embed_dim=96):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size,stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.patch_embed(x)      #[n, embed_dim, h', w']
        x = x.flatten(2)       #[n, embed_dim, h'w']
        x = x.permute(0, 2, 1)  # [n,  h'*w', embed_dim]
        x = self.norm(x)
        return x

class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim):
        super().__init__()
        self.resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear( 4 * dim ,2 * dim)
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x):
        h, w = self.resolution
        b, _, c = x.shape   # _ 不用,其是 num_patches,即 h*w
        x = x.reshape([b, h, w, c])

        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 1::2, 0::2, :]
        x3 = x[:, 1::2, 1::2, :]

        x = torch.concat([x0, x1, x2, x3], axis=-1)  # [B, h/2, w/2, 4c]
        x = x.reshape([b, -1, 4*c])
        x = self.norm(x)
        x = self.reduction(x)

        return x

class Mlp(nn.Module):
    def __init__(self, dim, mlp_ratio=4.0, dropout=0.):
        super().__init__()
        self.fc1 = nn.Linear(dim, int(dim * mlp_ratio))
        self.fc2 = nn.Linear(int(dim * mlp_ratio),dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

def windows_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.reshape([B, H//window_size,window_size, W//window_size, window_size, C])
    x = x.permute([0,1, 3, 2, 4, 5])
    # [B, h//ws, w//ws, ws, ws, c]
    x = x.reshape([-1, window_size, window_size, C])
    # [B * num_patches, ws, ws, c]
    return x


def windows_reverse(windows, window_size, H, W):
    B = int(windows.shape[0]// (H/window_size * W/window_size))
    x = windows.reshape([B, H//window_size, W//window_size, window_size, window_size, -1])
    x = x.permute([0, 1 ,3, 2, 4, 5])
    x= x.reshape([B, H, W, -1])
    return x

定义了WindowAttention,我们将它组合起来

在这里插入图片描述

class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.dim_head = dim// num_heads
        self.num_heads = num_heads
        self.scale = self.dim_head ** -0.5
        self.softmax = nn.Softmax(-1)
        self.qkv = nn.Linear(dim,
                             dim * 3)
        self.proj = nn.Linear(dim, dim)

    def tranpose_multi_head(self, x):
        new_shape = x.shape[:-1] + (self.num_heads, self.dim_head)
        x = x.reshape(new_shape)
        x = x.permute(0, 2, 1, 3)  #[B, num_heads, num_patches, dim_head]
        return x

    def forward(self,x):
        # x: [B, num_patches, embed_dim]
        B, N, C = x.shape
        qkv = self.qkv(x).chunk(3, -1)
        q, k, v = map(self.tranpose_multi_head, qkv)

        q = q * self.scale
        attn = torch.matmul(q, k.transpose(-1,-2))
        attn = self.softmax(attn)

        out = torch.matmul(attn, v)   # [B, num_heads, num_patches, dim_head]
        out = out.permute([0, 2, 1, 3])
        #  # [B, num_patches, num_heads, dim_head]    num_heads * dim_head= embed_dim
        out = out.reshape([B, N, C])
        out = self.proj(out)
        return out


class SwinBlock(nn.Module):
    def __init__(self, dim, input_reslution, num_heads, window_size):
        super().__init__()
        self.dim = dim
        self.reolution = input_reslution
        self.window_size =window_size

        self.attn_norm = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, window_size,num_heads)

        self.mlp_norm = nn.LayerNorm(dim)
        self.mlp = Mlp(dim)

    def forward(self,x):
        H, W = self.reolution
        B, N, C =x.shape


        h = x
        s = self.attn_norm(x)
        #切 window
        x =  x.reshape([B, H, W, C])
        x_windows = windows_partition(x, self.window_size)
        # [B * num_patches, ws, ws, c]
        x_windows = x_windows.reshape([-1,self.window_size*self.window_size, C])

        attn_windows = self.attn(x_windows)

        # 做完attention 将它复原
        attn_windows = attn_windows.reshape([-1, self.window_size, self.window_size, C])
        x = windows_reverse(attn_windows, self.window_size, H, W)
        # [B, H ,W ,C]
        # 但是做mlp中 输入不是它
        x = x.reshape([B, H*W, C])

        x = h + x

        h = x

        x = self.mlp_norm(x)
        x = self.mlp(x)
        x = h + x
        return x

最终用一个主函数去调用


def main():
    t = torch.randn([4, 3, 224, 224])
    patch_embedding = PatchEmbedding(patch_size=4, embed_dim=96)
    swinBlock = SwinBlock(dim=96, input_reslution=[56,56], num_heads=4, window_size=7)
    patch_merging = PatchMerging(input_resolution=[56,56], dim=96)

    out = patch_embedding(t)  #[4, 56, 56, 96]
    print('path_embedding out shape= ',out.shape)
    out = swinBlock(out)
    print('swinBlock out shape= ',out.shape)
    out = patch_merging(out)
    print('patch_merging out shape= ',out.shape)



if __name__ == '__main__':
    main()

在这里插入图片描述

首先我们输入的是一个batch的data,[4, 3, 224, 224],batch_size为4,我们通过patch_embedding操作,取一定大小的patch,patch_size 为4,所以变换完后tensor变为[4,56,56,96],3136是为了下一步做attention的时候方便。

在swinBlock,其中是做了windows_partition、WindowAttention,其中是不变换维度大小的

最后做了patch_merging,类似池化,相邻的4个token合并,维度扩大了2倍,即96变为192,而784是28×28,即56×56缩小了两倍

在这里插入图片描述

所以WindowAttention主要也是reshape再变回去

猜你喜欢

转载自blog.csdn.net/qq_45807235/article/details/129178939