LPRNet 中 MaxPool3d 无法导出 ONNX 的改写方案

使用 YOLOv5 + LPRNet 进行道路车牌检测和识别。车牌识别模型训练好后,使用 torch.onnx.export 导出时遇到了报错,主要问题是 ONNX 导出不支持这里使用的 nn.MaxPool3d 算子(它作用在 4 维特征图上,并且会沿通道维做池化),无法直接进行转换。本文主要针对这一问题进行处理。

1、LPRNet 主干代码

从 LPRNet 的主干代码中可以看到,共有三处用到了 nn.MaxPool3d,这是本次模型转 ONNX 失败的关键。nn.MaxPool3d 只是求取窗口内的最大值,本身没有可训练的权重,因此只需改写其实现方式、得到相同的输出即可。

class LPRNet(nn.Module):
    """LPRNet licence-plate recognition network (original version using nn.MaxPool3d).

    Args:
        lpr_max_len: stored on the module; not used by the code shown here.
        phase: stored on the module; not used by the code shown here.
        class_num: number of output classes (the shape comments use 68).
        dropout_rate: probability for the two ``nn.Dropout`` layers.

    The shape comments below assume a ``[bs, 3, 24, 94]`` input.
    NOTE: ``nn.MaxPool3d`` is applied to 4-D feature maps here, so the first
    kernel/stride entry pools along the channel axis (e.g. 128 -> 64 channels
    at index 7). This is the operator that blocks ONNX export.
    """

    def __init__(self, lpr_max_len, phase, class_num, dropout_rate):
        super(LPRNet, self).__init__()
        self.phase = phase
        self.lpr_max_len = lpr_max_len
        self.class_num = class_num
        self.backbone = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1),  # 0  [bs,3,24,94] -> [bs,64,22,92]
            nn.BatchNorm2d(num_features=64),  # 1  -> [bs,64,22,92]
            nn.ReLU(),  # 2  -> [bs,64,22,92]
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 1, 1)),  # 3  -> [bs,64,20,90]
            small_basic_block(ch_in=64, ch_out=128),  # 4  -> [bs,128,20,90]
            nn.BatchNorm2d(num_features=128),  # 5  -> [bs,128,20,90]
            nn.ReLU(),  # 6  -> [bs,128,20,90]
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(2, 1, 2)),  # 7  -> [bs,64,18,44]
            small_basic_block(ch_in=64, ch_out=256),  # 8  -> [bs,256,18,44]
            nn.BatchNorm2d(num_features=256),  # 9  -> [bs,256,18,44]
            nn.ReLU(),  # 10 -> [bs,256,18,44]
            small_basic_block(ch_in=256, ch_out=256),  # 11 -> [bs,256,18,44]
            nn.BatchNorm2d(num_features=256),  # 12 -> [bs,256,18,44]
            nn.ReLU(),  # 13 -> [bs,256,18,44]
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(4, 1, 2)),  # 14 -> [bs,64,16,21]
            nn.Dropout(dropout_rate),  # 15 -> [bs,64,16,21]
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 4), stride=1),  # 16 -> [bs,256,16,18]
            nn.BatchNorm2d(num_features=256),  # 17 -> [bs,256,16,18]
            nn.ReLU(),  # 18 -> [bs,256,16,18]
            nn.Dropout(dropout_rate),  # 19 -> [bs,256,16,18]
            nn.Conv2d(in_channels=256, out_channels=class_num, kernel_size=(13, 1), stride=1),
            # 20 -> [bs,68,4,18] (class_num=68)
            nn.BatchNorm2d(num_features=class_num),  # 21 -> [bs,68,4,18]
            nn.ReLU(),  # 22 -> [bs,68,4,18]
        )
        # 448 = 64 + 128 + 256 channels from the first three kept features,
        # plus class_num channels from the last one (see forward()).
        self.container = nn.Sequential(
            nn.Conv2d(in_channels=448 + self.class_num, out_channels=self.class_num, kernel_size=(1, 1), stride=(1, 1)),
        )

    def forward(self, x):
        """Return class scores; per the shape comments, [bs, class_num, 18] for a [bs, 3, 24, 94] input."""
        keep_features = []
        for i, layer in enumerate(self.backbone.children()):
            x = layer(x)
            # Keep the ReLU outputs at indices 2, 6, 13 and 22 for the global context.
            if i in (2, 6, 13, 22):
                keep_features.append(x)

        global_context = []
        # keep_features: [bs,64,22,92]  [bs,128,20,90] [bs,256,18,44] [bs,68,4,18]
        for i, f in enumerate(keep_features):
            if i in (0, 1):
                # [bs,64,22,92] -> [bs,64,4,18]
                # [bs,128,20,90] -> [bs,128,4,18]
                f = nn.functional.avg_pool2d(f, kernel_size=5, stride=5)
            elif i == 2:
                # [bs,256,18,44] -> [bs,256,4,18]
                f = nn.functional.avg_pool2d(f, kernel_size=(4, 10), stride=(4, 2))

            f_pow = torch.pow(f, 2)
            # NOTE(review): the mean is taken over the whole tensor, batch axis
            # included, so each sample's scaling depends on the rest of the batch.
            f_mean = torch.mean(f_pow)
            f = torch.div(f, f_mean)
            global_context.append(f)

        x = torch.cat(global_context, 1)  # [bs,516,4,18] when class_num=68
        x = self.container(x)  # -> [bs, class_num, 4, 18]
        logits = torch.mean(x, dim=2)  # -> [bs, class_num, 18]

        return logits

2、修改nn.MaxPool3d代码

以下是按 batch 为 1 编写的代码,核心思想是用一个 MaxPool2d(在 H、W 维上池化)和一个 MaxPool1d(在通道维上池化)替代 nn.MaxPool3d:

class MaxPool3d(torch.nn.Module):
    """ONNX-exportable replacement for ``torch.nn.MaxPool3d`` applied to 4-D inputs.

    ``torch.nn.MaxPool3d`` treats a 4-D tensor as an unbatched ``(C, D, H, W)``
    volume. For a feature map ``(B, C, H, W)`` it therefore pools each sample
    over its (channel, height, width) axes. This module reproduces that result:
    a ``MaxPool2d`` over (H, W), then a ``MaxPool1d`` over the channel axis.
    Max-pooling is separable, so the composition is exact.

    Args:
        kernel_size: ``(k_c, k_h, k_w)`` window size.
        stride: ``(s_c, s_h, s_w)`` strides.
        padding: ``(p_c, p_h, p_w)`` implicit -inf padding.
    """

    def __init__(self, kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0)):
        # Argument-less super(): the original named a non-existent class
        # (MaxPool3d_modify) here, which raised NameError on construction.
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # Spatial pooling over (H, W) with the last two entries.
        self.max_pool_2d = torch.nn.MaxPool2d(kernel_size[1:], self.stride[1:], padding[1:])
        # Pooling along the channel axis with the first entries.
        self.max_pool_1d = torch.nn.MaxPool1d(kernel_size=kernel_size[0], stride=self.stride[0], padding=self.padding[0])

    def forward(self, x):
        """Pool ``x`` of shape (B, C, H, W) into (B, C', H', W'); works for any B."""
        x = self.max_pool_2d(x)  # (B, C, H', W')
        b, c, h, w = x.shape
        # MaxPool1d pools the last axis of an (N, C, L) tensor, so move channels
        # last and fold H', W' into the middle axis.
        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)  # (B, H'*W', C)
        x = self.max_pool_1d(x)  # (B, H'*W', C')
        return x.reshape(b, h, w, -1).permute(0, 3, 1, 2)  # (B, C', H', W')


# The comparison snippet in this file refers to the class by this name.
MaxPool3d_modify = MaxPool3d

支持多 batch 的替代方案如下:

class MaxPool3d_muti_batch(torch.nn.Module):
    """Batch-size-agnostic, ONNX-exportable stand-in for ``torch.nn.MaxPool3d`` on 4-D inputs.

    For a 4-D input ``(B, C, H, W)``, ``torch.nn.MaxPool3d`` treats the tensor as an
    unbatched ``(C, D, H, W)`` volume. It therefore pools each sample independently
    over (channel, height, width). This module computes the same result with a
    ``MaxPool2d`` over (H, W) followed by a ``MaxPool1d`` over channels.

    There is deliberately no Python loop over the batch. When traced for ONNX
    export, such a loop is unrolled, which fixes the graph to the export batch size.

    Args:
        kernel_size: ``(k_c, k_h, k_w)`` window size.
        stride: ``(s_c, s_h, s_w)`` strides.
        padding: ``(p_c, p_h, p_w)`` implicit -inf padding.
    """

    def __init__(self, kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0)):
        super(MaxPool3d_muti_batch, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # Spatial pooling over (H, W) with the last two entries.
        self.max_pool_2d = torch.nn.MaxPool2d(kernel_size[1:], self.stride[1:], padding[1:])
        # Pooling along the channel axis with the first entries.
        self.max_pool_1d = torch.nn.MaxPool1d(kernel_size=kernel_size[0], stride=self.stride[0],
                                              padding=self.padding[0])

    def forward(self, x):
        """Pool ``x`` of shape (B, C, H, W) into (B, C', H', W')."""
        x = self.max_pool_2d(x)  # (B, C, H', W')
        b, c, h, w = x.shape
        # MaxPool1d pools the last axis of an (N, C, L) tensor: move channels
        # last and fold H', W' together so every sample/position is pooled at once.
        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)  # (B, H'*W', C)
        x = self.max_pool_1d(x)  # (B, H'*W', C')
        return x.reshape(b, h, w, -1).permute(0, 3, 1, 2)  # (B, C', H', W')

可通过如下代码验证输出是否一致:

    # Compare nn.MaxPool3d with the two replacements on a batch-1 input.
    # `x` instead of `input` avoids shadowing the builtin.
    x = torch.rand(1, 3, 24, 94)
    model1 = torch.nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 0, 0))
    model2 = MaxPool3d_muti_batch(kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 0, 0))
    model3 = MaxPool3d_modify(kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 0, 0))
    out1 = model1(x)
    out2 = model2(x)
    out3 = model3(x)
    print(out1.shape)
    print(out2.shape)
    print(out3.shape)
    # torch.equal also compares shapes, so a shape mismatch is reported as
    # "not equal" instead of raising or broadcasting like torch.eq would.
    if torch.equal(out1, out2):
        print("out1 is equal to out2")
    else:
        print("out1 is not equal to out2")

    if torch.equal(out1, out3):
        print("out1 is equal to out3")
    else:
        print("out1 is not equal to out3")

    # Also exercise the multi-batch module with batch > 1 and a channel stride,
    # as used at index 7 of the LPRNet backbone. nn.MaxPool3d treats this 4-D
    # input as unbatched (C, D, H, W), so each sample is pooled independently.
    x_batch = torch.rand(4, 128, 20, 90)
    pool_ref = torch.nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(2, 1, 2))
    pool_multi = MaxPool3d_muti_batch(kernel_size=(1, 3, 3), stride=(2, 1, 2))
    if torch.equal(pool_ref(x_batch), pool_multi(x_batch)):
        print("multi-batch output is equal to nn.MaxPool3d")
    else:
        print("multi-batch output is not equal to nn.MaxPool3d")

3、替换nn.MaxPool3d

将 LPRNet 中所有 nn.MaxPool3d 替换为 MaxPool3d_muti_batch 模块,可以验证输出保持一致。
替换后的代码如下:

class LPRNet_modify(nn.Module):
    """LPRNet with every ``nn.MaxPool3d`` replaced by ``MaxPool3d_muti_batch`` for ONNX export.

    The layer layout (and therefore the Sequential indices) is identical to
    the original LPRNet. Only the three pooling layers (indices 3, 7 and 14)
    are swapped for the ONNX-friendly module.

    Args:
        lpr_max_len: stored on the module; not used by the code shown here.
        phase: stored on the module; not used by the code shown here.
        class_num: number of output classes (the shape comments use 68).
        dropout_rate: probability for the two ``nn.Dropout`` layers.

    The shape comments below assume a ``[bs, 3, 24, 94]`` input.
    """

    def __init__(self, lpr_max_len, phase, class_num, dropout_rate):
        super(LPRNet_modify, self).__init__()
        self.phase = phase
        self.lpr_max_len = lpr_max_len
        self.class_num = class_num
        self.backbone = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1),  # 0  [bs,3,24,94] -> [bs,64,22,92]
            nn.BatchNorm2d(num_features=64),  # 1  -> [bs,64,22,92]
            nn.ReLU(),  # 2  -> [bs,64,22,92]
            MaxPool3d_muti_batch(kernel_size=(1, 3, 3), stride=(1, 1, 1)),  # 3  -> [bs,64,20,90]
            small_basic_block(ch_in=64, ch_out=128),  # 4  -> [bs,128,20,90]
            nn.BatchNorm2d(num_features=128),  # 5  -> [bs,128,20,90]
            nn.ReLU(),  # 6  -> [bs,128,20,90]
            MaxPool3d_muti_batch(kernel_size=(1, 3, 3), stride=(2, 1, 2)),  # 7  -> [bs,64,18,44]
            small_basic_block(ch_in=64, ch_out=256),  # 8  -> [bs,256,18,44]
            nn.BatchNorm2d(num_features=256),  # 9  -> [bs,256,18,44]
            nn.ReLU(),  # 10 -> [bs,256,18,44]
            small_basic_block(ch_in=256, ch_out=256),  # 11 -> [bs,256,18,44]
            nn.BatchNorm2d(num_features=256),  # 12 -> [bs,256,18,44]
            nn.ReLU(),  # 13 -> [bs,256,18,44]
            MaxPool3d_muti_batch(kernel_size=(1, 3, 3), stride=(4, 1, 2)),  # 14 -> [bs,64,16,21]
            nn.Dropout(dropout_rate),  # 15 -> [bs,64,16,21]
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 4), stride=1),  # 16 -> [bs,256,16,18]
            nn.BatchNorm2d(num_features=256),  # 17 -> [bs,256,16,18]
            nn.ReLU(),  # 18 -> [bs,256,16,18]
            nn.Dropout(dropout_rate),  # 19 -> [bs,256,16,18]
            nn.Conv2d(in_channels=256, out_channels=class_num, kernel_size=(13, 1), stride=1),
            # 20 -> [bs,68,4,18] (class_num=68)
            nn.BatchNorm2d(num_features=class_num),  # 21 -> [bs,68,4,18]
            nn.ReLU(),  # 22 -> [bs,68,4,18]
        )
        # 448 = 64 + 128 + 256 channels from the first three kept features,
        # plus class_num channels from the last one (see forward()).
        self.container = nn.Sequential(
            nn.Conv2d(in_channels=448 + self.class_num, out_channels=self.class_num, kernel_size=(1, 1), stride=(1, 1)),
        )

    def forward(self, x):
        """Return class scores; per the shape comments, [bs, class_num, 18] for a [bs, 3, 24, 94] input.

        Without this method, calling the module (or tracing it with
        torch.onnx.export) raises NotImplementedError from nn.Module.
        """
        keep_features = []
        for i, layer in enumerate(self.backbone.children()):
            x = layer(x)
            # Keep the ReLU outputs at indices 2, 6, 13 and 22 for the global context.
            if i in (2, 6, 13, 22):
                keep_features.append(x)

        global_context = []
        # keep_features: [bs,64,22,92]  [bs,128,20,90] [bs,256,18,44] [bs,68,4,18]
        for i, f in enumerate(keep_features):
            if i in (0, 1):
                # [bs,64,22,92] -> [bs,64,4,18]
                # [bs,128,20,90] -> [bs,128,4,18]
                f = nn.functional.avg_pool2d(f, kernel_size=5, stride=5)
            elif i == 2:
                # [bs,256,18,44] -> [bs,256,4,18]
                f = nn.functional.avg_pool2d(f, kernel_size=(4, 10), stride=(4, 2))

            f_pow = torch.pow(f, 2)
            # NOTE(review): the mean is taken over the whole tensor, batch axis
            # included, so each sample's scaling depends on the rest of the batch.
            f_mean = torch.mean(f_pow)
            f = torch.div(f, f_mean)
            global_context.append(f)

        x = torch.cat(global_context, 1)  # [bs,516,4,18] when class_num=68
        x = self.container(x)  # -> [bs, class_num, 4, 18]
        logits = torch.mean(x, dim=2)  # -> [bs, class_num, 18]

        return logits

完成以上替换并训练模型后,即可直接成功导出 ONNX。

猜你喜欢

转载自blog.csdn.net/caobin_cumt/article/details/129814831