Building a DBnet Text Detection Network with OpenCV and PyTorch

Preface:

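This article implements the DBnet text detection network from scratch in PyTorch. The theory behind DBnet is not repeated here; the focus is on the code. To train on your own dataset, convert it to the format used in this article.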

Building the network:

Network architecture:
[Figure: DBnet network architecture]
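The backbone is ResNet18.

The decoder works as follows:
1. The backbone features at strides 4, 8, 16 and 32 are fused top-down.
2. Each level is upsampled to 1/4 of the input resolution.
3. The levels are concatenated.
4. The result is upsampled back to the input size.

The network input has size batch x 3 x 640 x 640 and the output has size batch x 3 x 640 x 640. The 3 output channels are the probability map, the threshold map and the approximate binary map.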
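The approximate binary map is the differentiable binarization step of DBnet. It is computed element-wise from the probability map P and the threshold map T, with the amplifying factor k, which this implementation sets to 50:

```latex
\hat{B}_{i,j} = \frac{1}{1 + e^{-k\,(P_{i,j} - T_{i,j})}}, \qquad k = 50
```

This is a steep sigmoid of P − T. For example, P − T = 0.1 gives about 0.993, so it behaves almost like a hard threshold. Unlike a hard threshold, it stays differentiable, so gradients reach both P and T.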
The code is as follows:

```python
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models.feature_extraction import create_feature_extractor


class Upsample_add(nn.Module):
    """Upsample a deeper feature map 2x and add the lateral (skip) feature map."""
    def __init__(self, ins, outs):
        super(Upsample_add, self).__init__()
        self.up = nn.Sequential(
            nn.ConvTranspose2d(ins, outs, kernel_size=2, stride=2),
            nn.ReLU()
        )

    def forward(self, x, x1):
        # the lateral map x1 is passed at call time instead of being stored on the module
        return self.up(x) + x1


class Upsample_N(nn.Module):
    """Upsample by a factor of N, then refine with a 3x3 convolution."""
    def __init__(self, ins, outs, N):
        super(Upsample_N, self).__init__()
        self.up = nn.Sequential(
            nn.ConvTranspose2d(ins, outs, kernel_size=N, stride=N),
            nn.ReLU()
        )
        self.conv = nn.Sequential(
            nn.Conv2d(outs, outs, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

    def forward(self, x):
        return self.conv(self.up(x))


class DBnet(nn.Module):
    def __init__(self, k=50):
        super(DBnet, self).__init__()
        # ResNet18 backbone returning features at strides 2, 4, 8, 16 and 32
        # (weights=None means random init; models.ResNet18_Weights.DEFAULT loads ImageNet weights)
        self.backbone = create_feature_extractor(
            models.resnet18(weights=None),
            {'relu': 'feature2', 'layer1': 'feature4', 'layer2': 'feature8',
             'layer3': 'feature16', 'layer4': 'feature32'})
        # Decoder layers are created here, not in forward(), so their weights are
        # registered as parameters: saved in state_dict, moved by .to(device), and trained
        self.block16 = Upsample_add(512, 256)
        self.block8 = Upsample_add(256, 128)
        self.block4 = Upsample_add(128, 64)
        self.up32 = Upsample_N(512, 100, 8)
        self.up16 = Upsample_N(256, 100, 4)
        self.up8 = Upsample_N(128, 100, 2)
        self.up4 = Upsample_N(64, 100, 1)
        self.up_out = Upsample_N(400, 100, 4)
        # Head: channel 0 = probability map P, channel 1 = threshold map T
        self.conv = nn.Sequential(
            nn.Conv2d(100, 2, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )
        self.k = k  # amplifying factor of differentiable binarization

    def forward(self, x):
        dicts = self.backbone(x)
        # Top-down fusion: 1/32 -> 1/16 -> 1/8 -> 1/4
        x32 = dicts['feature32']
        x16 = self.block16(x32, dicts['feature16'])
        x8 = self.block8(x16, dicts['feature8'])
        x4 = self.block4(x8, dicts['feature4'])

        # Bring every level to 1/4 resolution with 100 channels each, then concatenate
        x = torch.cat((self.up4(x4), self.up8(x8), self.up16(x16), self.up32(x32)), dim=1)
        x = self.up_out(x)  # 400 channels at 1/4 -> 100 channels at full resolution
        x = self.conv(x)    # batch x 2 x H x W

        P = x[:, 0:1]  # slicing keeps the channel dimension
        T = x[:, 1:2]
        B = torch.sigmoid(self.k * (P - T))  # same as 1 / (1 + exp(-k * (P - T)))
        return torch.cat((x, B), dim=1)  # batch x 3 x H x W: P, T, approximate B


if __name__ == '__main__':
    dbnet = DBnet()
    out = dbnet(torch.rand(8, 3, 640, 640))
    print(out.shape)  # torch.Size([8, 3, 640, 640])
    torch.save(dbnet.state_dict(), 'model.pth')
```
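Two questions come up when reading this decoder:
- Why does every branch land on the same grid before concatenation?
- Why does the fused map end at 640 x 640?

The sketch below answers both with plain Python arithmetic. It applies the output-size formula documented for `nn.ConvTranspose2d`:

H_out = (H_in − 1)·stride − 2·padding + dilation·(kernel_size − 1) + output_padding + 1

It uses the defaults of this network: padding 0, dilation 1, output_padding 0. It does not call PyTorch. The helper names `conv_transpose_out` and `decoder_sizes` are hypothetical.

```python
def conv_transpose_out(h_in, kernel_size, stride, padding=0, dilation=1, output_padding=0):
    """Output size of nn.ConvTranspose2d along one spatial dimension."""
    return (h_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1


def decoder_sizes(input_size=640):
    """Trace spatial sizes through the DBnet decoder for a square input (hypothetical helper)."""
    # Backbone feature sizes at strides 4, 8, 16, 32 (exact only for multiples of 32)
    f4, f8, f16, f32 = (input_size // s for s in (4, 8, 16, 32))
    # Top-down path: kernel_size=2, stride=2 doubles the size, which must match the lateral map
    x16 = conv_transpose_out(f32, 2, 2)
    x8 = conv_transpose_out(x16, 2, 2)
    x4 = conv_transpose_out(x8, 2, 2)
    assert (x16, x8, x4) == (f16, f8, f4), "lateral maps must match for the addition"
    # Each level is brought to 1/4 resolution with kernel_size=stride=N
    branches = {
        'x32': conv_transpose_out(f32, 8, 8),
        'x16': conv_transpose_out(x16, 4, 4),
        'x8': conv_transpose_out(x8, 2, 2),
        'x4': conv_transpose_out(x4, 1, 1),
    }
    # The concatenated 400-channel map is upsampled 4x back to the input size
    final = conv_transpose_out(branches['x4'], 4, 4)
    return branches, final


branches, final = decoder_sizes(640)
print(branches)  # {'x32': 160, 'x16': 160, 'x8': 160, 'x4': 160}
print(final)     # 640
```

The `//` shortcut for the backbone sizes assumes the input side is a multiple of 32, which holds for 640. For other sizes the ResNet stages round up, so the lateral additions in `Upsample_add` can fail with a shape mismatch.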

Dataset:

We use the ICPR text detection dataset.
Download link:
https://tianchi.aliyun.com/competition/entrance/231685/information
The format is shown below: each image has a corresponding txt file.
Image:
[Figure: sample training image]
txt file:
[Figure: the corresponding txt annotation]
On each line, the first 8 numbers are x1, y1, x2, y2, x3, y3, x4, y4, the coordinates of the four vertices of a quadrilateral text region.
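The sketch below reads one such txt file. It rests on two assumptions to check against the actual files:
- Each line holds the eight comma-separated coordinates followed by the text transcription. Any further commas belong to the text, so the line is split at most 8 times.
- A transcription of `###` marks illegible text.

The network takes 640 x 640 inputs, so the sketch also rescales a quadrilateral to match a resized image. The functions `parse_line`, `load_annotations` and `rescale_quad`, and the example lines, are hypothetical.

```python
def parse_line(line):
    """Parse 'x1,y1,x2,y2,x3,y3,x4,y4,text' into ([(x, y)] * 4, text) (assumed layout)."""
    parts = line.strip().split(',', 8)  # keep commas inside the transcription
    coords = [float(v) for v in parts[:8]]
    text = parts[8] if len(parts) > 8 else ''
    quad = list(zip(coords[0::2], coords[1::2]))  # [(x1, y1), ..., (x4, y4)]
    return quad, text


def load_annotations(lines):
    """Return one dict per box; '###' transcriptions are flagged as ignore (assumed convention)."""
    boxes = []
    for line in lines:
        if not line.strip():
            continue  # skip blank lines
        quad, text = parse_line(line)
        boxes.append({'quad': quad, 'text': text, 'ignore': text == '###'})
    return boxes


def rescale_quad(quad, img_w, img_h, target=640):
    """Map vertex coordinates from the original image to a target x target resized image."""
    sx, sy = target / img_w, target / img_h
    return [(x * sx, y * sy) for x, y in quad]


# Hypothetical example lines, not taken from the dataset
lines = [
    "10.0,20.0,110.0,20.0,110.0,60.0,10.0,60.0,Hello, world",
    "0,0,32,0,32,16,0,16,###",
]
boxes = load_annotations(lines)
print(boxes[0]['text'])    # Hello, world
print(boxes[1]['ignore'])  # True
print(rescale_quad(boxes[0]['quad'], img_w=1280, img_h=640))
# [(5.0, 20.0), (55.0, 20.0), (55.0, 60.0), (5.0, 60.0)]
```

Scaling x and y independently matches a plain resize to 640 x 640 that ignores aspect ratio. A letterboxed resize would also need the padding offset added.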

That is all for now; more will follow in later updates.


Reposted from blog.csdn.net/Czqqwer/article/details/130071335