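Preface:
This article implements the DBnet text detection network from scratch in PyTorch. The theory behind DBnet is not covered in detail here; the focus is on the code. To train on your own data, convert your dataset to the format used in this article.
Building the network:
Network architecture diagram:
The feature extractor is ResNet18. DBnet takes an input of size batch x 3 x 640 x 640 and produces an output of size batch x 3 x 640 x 640, whose three channels are, in order, the probability map (P), the threshold map (T) and the approximate binary map (B). In the code, the stride-32, 16, 8 and 4 ResNet18 features are merged top-down (2x upsampling followed by addition), each merged level is upsampled to 1/4 of the input size and concatenated, and the result is upsampled back to 640 x 640 before a 3x3 convolution with a sigmoid predicts P and T.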
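The third channel is not predicted by a convolution: as in the DB paper, it is computed from P and T with the differentiable binarization function, where k is an amplification factor (the code uses k = 50). Because this is just a steep sigmoid of P - T, it can be written with the logistic function σ:

```latex
\hat{B}_{i,j} = \frac{1}{1 + e^{-k\,(P_{i,j} - T_{i,j})}} = \sigma\!\big(k\,(P_{i,j} - T_{i,j})\big), \qquad k = 50
```

For example, P - T = 0.1 gives σ(5) ≈ 0.993, while P - T = -0.1 gives σ(-5) ≈ 0.007, so B behaves almost like a hard threshold of P at T while remaining differentiable.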
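The code is as follows:
```python
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models.feature_extraction import create_feature_extractor


class Upsample_add(nn.Module):
    """Upsample a deeper feature map 2x with a transposed conv, then add the lateral feature map."""
    def __init__(self, ins, outs):
        super(Upsample_add, self).__init__()
        self.up = nn.Sequential(
            nn.ConvTranspose2d(ins, outs, kernel_size=2, stride=2),
            nn.ReLU()
        )

    def forward(self, x, x1):
        # x: deeper feature (ins channels, H/2 x W/2); x1: lateral feature (outs channels, H x W)
        return self.up(x) + x1


class Upsample_N(nn.Module):
    """Upsample N times with a transposed conv (kernel_size = stride = N), then apply a 3x3 conv."""
    def __init__(self, ins, outs, N):
        super(Upsample_N, self).__init__()
        self.up = nn.Sequential(
            nn.ConvTranspose2d(ins, outs, kernel_size=N, stride=N),
            nn.ReLU()
        )
        self.conv = nn.Sequential(
            nn.Conv2d(outs, outs, stride=1, kernel_size=3, padding=1),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.up(x)
        x = self.conv(x)
        return x


class DBnet(nn.Module):
    def __init__(self, k=50):
        super(DBnet, self).__init__()
        self.k = k  # amplification factor of the differentiable binarization
        # ResNet18 with random init (weights=None replaces the deprecated pretrained=False, torchvision >= 0.13)
        resnet18 = models.resnet18(weights=None)
        # Return the outputs of relu and layer1..layer4 (strides 2, 4, 8, 16, 32)
        self.backbone = create_feature_extractor(
            resnet18, {
                'relu': 'feature2', 'layer1': 'feature4', 'layer2': 'feature8',
                'layer3': 'feature16', 'layer4': 'feature32'})
        # Every layer is created here (not inside forward) so that it is registered,
        # trained by the optimizer, moved by .to(device) and stored in state_dict()
        self.add16 = Upsample_add(512, 256)
        self.add8 = Upsample_add(256, 128)
        self.add4 = Upsample_add(128, 64)
        # Bring each pyramid level to 1/4 resolution with 100 channels
        self.up32 = Upsample_N(512, 100, 8)
        self.up16 = Upsample_N(256, 100, 4)
        self.up8 = Upsample_N(128, 100, 2)
        self.up4 = Upsample_N(64, 100, 1)
        # Fused 400-channel map -> full input resolution
        self.up_out = Upsample_N(400, 100, 4)
        # Head: 2 channels = probability map P and threshold map T
        self.conv = nn.Sequential(
            nn.Conv2d(100, 2, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        dicts = self.backbone(x)
        x32 = dicts['feature32']
        # Top-down path: upsample 2x and add the lateral backbone feature
        x16 = self.add16(x32, dicts['feature16'])
        x8 = self.add8(x16, dicts['feature8'])
        x4 = self.add4(x8, dicts['feature4'])
        # Upsample every level to 1/4 resolution and concatenate (4 x 100 = 400 channels)
        x = torch.cat((self.up4(x4), self.up8(x8), self.up16(x16), self.up32(x32)), dim=1)
        x = self.up_out(x)
        x = self.conv(x)
        P = x[:, 0:1]  # keep the channel dimension
        T = x[:, 1:2]
        # B = 1 / (1 + exp(-k (P - T))), computed with torch.sigmoid for numerical stability
        B = torch.sigmoid(self.k * (P - T))
        return torch.cat((x, B), dim=1)


if __name__ == '__main__':
    dbnet = DBnet()
    torch.save(dbnet.state_dict(), 'model.pth')
    out = dbnet(torch.rand(8, 3, 640, 640))
    print(out.shape)  # torch.Size([8, 3, 640, 640])
```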
Dataset:
The ICPR text detection dataset is used.
Dataset download:
https://tianchi.aliyun.com/competition/entrance/231685/information
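The dataset format is shown below: each image has a corresponding txt annotation file.
Image:
txt file:
On each line of the txt file, the first 8 numbers are x1,y1,x2,y2,x3,y3,x4,y4, the coordinates of the four vertices of a quadrilateral text box.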
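A minimal sketch of reading one annotation file, assuming comma-separated lines whose first 8 fields are the vertex coordinates described above and whose remaining fields (if any) hold the text transcription. `parse_icpr_line` and `load_icpr_txt` are hypothetical helper names, not part of any library, and the sample line below is made up for illustration.

```python
from typing import List, Tuple

Point = Tuple[float, float]


def parse_icpr_line(line: str) -> Tuple[List[Point], str]:
    """Parse one line: x1,y1,x2,y2,x3,y3,x4,y4 followed (by assumption) by the text label."""
    fields = line.strip().split(',')
    coords = [float(v) for v in fields[:8]]
    # Group the 8 numbers into the four (x, y) vertices of the quadrilateral
    quad = [(coords[i], coords[i + 1]) for i in range(0, 8, 2)]
    # Assumption: everything after the 8 numbers is the transcription; rejoin in case it contains commas
    text = ','.join(fields[8:])
    return quad, text


def load_icpr_txt(path: str) -> List[Tuple[List[Point], str]]:
    """Read every non-empty line of one annotation file."""
    boxes = []
    # 'utf-8-sig' also strips a UTF-8 byte-order mark if the file has one
    with open(path, encoding='utf-8-sig') as f:
        for line in f:
            if line.strip():
                boxes.append(parse_icpr_line(line))
    return boxes


quad, text = parse_icpr_line("10.5,20.0,110.5,20.0,110.5,60.0,10.5,60.0,hello")
print(quad, text)
```

Keeping the four vertices as a polygon (rather than converting to an axis-aligned box) preserves rotated and skewed text regions, which the probability and threshold map targets will need later.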
That's all for now; more will be added in later updates.