PyTorch Learning (25): Object Detection (Dataset Preparation)

Preface

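Previously, I got PyTorch implementations of YOLOv2/v3 and SSD-MobileNetV1/V2 object detection running in testing. Training a detection model on your own data is far more exciting than just testing one. As the saying goes, supplies move before the troops do: before any training, the training data has to be ready.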

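Many examples train and test on datasets in VOC or COCO format. Our own data is usually in neither format, so a rather clumsy approach is to write a conversion script, or, failing that, to create the folders by hand and copy the data into the designated directories. This is tedious, mainly for two reasons: 1. VOC stores labels as one XML file per image, and XML is more cumbersome to parse than lighter formats such as JSON or YAML. 2. The data ends up stored twice on disk: one copy in VOC format and one copy of the original data.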

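Below, the original data is used directly: it is wrapped with the classes PyTorch provides so that the dataset can be indexed and read, and it is then converted to VOC format.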

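To get started quickly, we use a public dataset, WIDER FACE. It is a face detection dataset whose training set contains about 12k images, with face bounding-box labels provided.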


Goals

  • Wrap the WIDER FACE data with PyTorch's Dataset API
  • Convert the WIDER FACE data to VOC format

Development environment

  • Ubuntu 18.04
  • PyCharm
  • Anaconda3, Python 3.6
  • PyTorch 1.0, torchvision

The WIDER FACE face detection dataset

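Labels

The training labels are in wider_face_train_bbx_gt.txt. For each image, the file contains the image name on one line, the number of faces on the next line, and then one line per face of the form x1, y1, w, h, blur, expression, illumination, invalid, occlusion, pose, where (x1, y1) is the top-left corner of the box and w, h are its width and height.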

Training set

For simplicity, copy wider_face_train_bbx_gt.txt into the training-set directory (WIDER_train); the images folder there contains the images.
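To inspect the labels without PyTorch, the ground-truth file can also be parsed in a single pass into a dictionary from image name to boxes. This is a minimal sketch, assuming the layout of wider_face_train_bbx_gt.txt is a name line, a face-count line, then one line per face starting with x, y, w, h. It also assumes that an image with 0 faces may be followed by one placeholder line of zeros and skips such a line. `parse_wider_gt` is a hypothetical helper, not part of any library.

```python
def parse_wider_gt(lines):
    """Parse WIDER FACE ground-truth lines into {image_name: [[x, y, w, h], ...]}.

    Assumed layout per image: a name line, a face-count line, then one line per
    face whose first four fields are x, y, w, h (the attribute fields are ignored).
    """
    # Drop surrounding whitespace and empty lines
    lines = [line.strip() for line in lines if line.strip()]
    boxes = {}
    i = 0
    while i < len(lines):
        name = lines[i]
        count = int(lines[i + 1])
        i += 2
        rects = []
        for _ in range(count):
            x, y, w, h = map(int, lines[i].split()[:4])
            rects.append([x, y, w, h])
            i += 1
        # Assumption: a 0-face entry may carry one placeholder line of zeros
        if count == 0 and i < len(lines) and not lines[i].endswith(('.jpg', '.bmp')):
            i += 1
        boxes[name] = rects
    return boxes


# Usage (path is illustrative):
# with open('wider_face_train_bbx_gt.txt') as f:
#     boxes = parse_wider_gt(f)
```

Parsing once up front means every later lookup is a dictionary access, not a scan through the whole file.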


Wrapping the data with the PyTorch Dataset API

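  • Code
    Defining a dataset in PyTorch is simple: just follow the pattern PyTorch provides. Usually you first define a class that inherits from torch.utils.data.Dataset and then override its __len__() and __getitem__() methods.
  1. __len__(): returns the number of samples in the dataset.

  2. __getitem__(): given an index, returns that sample together with its label data; it is called for every sample when the dataset is iterated.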
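Iterating over such a dataset (as `next(iter(dataset))` and `enumerate(original_dataset, 1)` do later) works because Python falls back to calling `__getitem__` with 0, 1, 2, ... until an `IndexError` is raised. The toy class below illustrates that protocol in plain Python, without PyTorch; `ToyDetectionDataset` and its data are made up for illustration.

```python
class ToyDetectionDataset:
    """Hypothetical stand-in for a map-style dataset: only __len__ and __getitem__."""

    def __init__(self, samples):
        # samples: list of (image_name, [[x, y, w, h], ...]) pairs
        self.samples = samples

    def __len__(self):
        # Number of samples in the dataset
        return len(self.samples)

    def __getitem__(self, index):
        # Indexing past the end raises IndexError, which also ends iteration
        image_name, rects = self.samples[index]
        return {'image_name': image_name, 'label': rects}


dataset = ToyDetectionDataset([
    ('a.jpg', [[10, 20, 30, 40]]),
    ('b.jpg', []),
])
print(len(dataset))                          # 2
print(dataset[0]['label'])                   # [[10, 20, 30, 40]]
print([s['image_name'] for s in dataset])    # ['a.jpg', 'b.jpg']
```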

import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset


class WiderFaceDataset(Dataset):
    def __init__(self, images_folder, ground_truth_file, transform=None, target_transform=None):
        super(WiderFaceDataset, self).__init__()
        self.images_folder = images_folder
        self.ground_truth_file = ground_truth_file
        # Every line of the label file, with trailing whitespace stripped
        with open(ground_truth_file, 'r') as f:
            self.ground_truth = [line.rstrip() for line in f]

        # Image names are the lines that end in .jpg/.bmp
        self.images_name_list = [line for line in self.ground_truth
                                 if line.endswith(('.jpg', '.bmp'))]
        # Map each image name to its line number so that each lookup is O(1)
        # instead of a scan over the whole label file for every sample
        self._name_to_line = {line: i for i, line in enumerate(self.ground_truth)
                              if line.endswith(('.jpg', '.bmp'))}

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images_name_list)

    def __getitem__(self, index):
        # Raises IndexError past the end, which is what stops iteration
        image_name = self.images_name_list[index]
        # Locate the image name in the label file
        loc = self._search(image_name)
        # The next line holds the number of faces
        face_nums = int(self.ground_truth[loc + 1])
        # Read the rectangles: the first 4 fields of each face line are x, y, w, h
        rects = []
        for i in range(loc + 2, loc + 2 + face_nums):
            x, y, w, h = map(int, self.ground_truth[i].split(' ')[:4])
            rects.append([x, y, w, h])

        # Image
        image_path = os.path.join(self.images_folder, image_name)
        image = PIL.Image.open(image_path)

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            rects = [self.target_transform(r) for r in rects]

        return {'image': image, 'label': rects, 'image_name': image_path}

    def _search(self, image_name):
        return self._name_to_line[image_name]


if __name__ == '__main__':
    images_folder = '/media/weipenghui/Extra/WiderFace/WIDER_train/images'
    ground_truth_file = '/media/weipenghui/Extra/WiderFace/WIDER_train/wider_face_train_bbx_gt.txt'

    dataset = WiderFaceDataset(images_folder=images_folder,
                               ground_truth_file=ground_truth_file,
                               transform=transforms.ToTensor(),
                               target_transform=lambda x: torch.tensor(x))

    # Take the first sample
    var = next(iter(dataset))
    image_transformed = var['image']
    label_transformed = var['label']
    image_name = var['image_name']
    # ToTensor gives a CxHxW float tensor in [0, 1]; convert back to HxWxC uint8 for display
    image_transformed = image_transformed.numpy().transpose((1, 2, 0))
    image_transformed = np.floor(image_transformed * 255).astype(np.uint8)
    # Draw the boxes on the image as read by OpenCV (BGR order, so (255, 0, 0) is blue)
    image = cv2.imread(image_name)
    for rect in label_transformed:
        x, y, w, h = [k.item() for k in rect]
        cv2.rectangle(image, pt1=(x, y), pt2=(x + w, y + h), color=(255, 0, 0))

    cv2.imshow('image', image)
    cv2.waitKey(0)
    plt.imshow(image_transformed)
    plt.show()

    # for i, sample in enumerate(dataset):
    #     print(i, sample['image'])
    #
    # print(len(dataset))


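Batching these samples needs some care: WIDER FACE images differ in size and each image has its own number of boxes, so PyTorch's default collate function generally cannot stack several samples into one batch. A common workaround is a custom `collate_fn` that keeps per-sample lists. The sketch below assumes the dict samples returned by `WiderFaceDataset`; `detection_collate` is a hypothetical name, passed as `DataLoader(dataset, batch_size=4, collate_fn=detection_collate)`.

```python
def detection_collate(batch):
    """Collate a list of WiderFaceDataset-style dicts without stacking.

    Images may differ in size and each image has its own number of boxes,
    so every field is kept as a plain Python list of length len(batch).
    """
    return {
        'image': [sample['image'] for sample in batch],
        'label': [sample['label'] for sample in batch],
        'image_name': [sample['image_name'] for sample in batch],
    }
```

Resizing or padding images to a common size (and stacking them) can then happen inside the training loop or in a transform.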

Converting WIDER FACE to VOC format

VOC format


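VOC data directory

A VOC dataset root contains the folders Annotations, ImageSets (with the subfolders Main, Layout, and Segmentation), JPEGImages, SegmentationClass, and SegmentationObject.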

Label XML files

The Annotations folder holds one XML label file per image.

Images

The JPEGImages folder holds the image files.

XML format

A sample VOC2007 annotation file:

<annotation>
    <folder>VOC2007</folder>
    <filename>000001.jpg</filename>
    <source>
        <database>The VOC2007 Database</database>
        <annotation>PASCAL VOC2007</annotation>
        <image>flickr</image>
        <flickrid>341012865</flickrid>
    </source>
    <owner>
        <flickrid>Fried Camels</flickrid>
        <name>Jinky the Fruit Bat</name>
    </owner>
    <size>
        <width>353</width>
        <height>500</height>
        <depth>3</depth>
    </size>
    <segmented>0</segmented>
    <object>
        <name>dog</name>
        <pose>Left</pose>
        <truncated>1</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>48</xmin>
            <ymin>240</ymin>
            <xmax>195</xmax>
            <ymax>371</ymax>
        </bndbox>
    </object>
    <object>
        <name>person</name>
        <pose>Left</pose>
        <truncated>1</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>8</xmin>
            <ymin>12</ymin>
            <xmax>352</xmax>
            <ymax>498</ymax>
        </bndbox>
    </object>
</annotation>
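To check an annotation file, for example one of those generated by the conversion script, it can be read back with the standard-library `xml.etree.ElementTree`. This sketch takes the parsed `<annotation>` element and returns the image size and a list of `(name, xmin, ymin, xmax, ymax)` tuples; `read_voc_boxes` is a hypothetical helper, not part of any VOC toolkit.

```python
import xml.etree.ElementTree as ET


def read_voc_boxes(root):
    """Extract ((width, height), [(name, xmin, ymin, xmax, ymax), ...]) from a VOC <annotation> element."""
    size = root.find('size')
    width = int(size.find('width').text)
    height = int(size.find('height').text)
    boxes = []
    for obj in root.findall('object'):
        bndbox = obj.find('bndbox')
        # Read coordinates as int; float() first tolerates values such as "48.0"
        coords = [int(float(bndbox.find(tag).text)) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        boxes.append((obj.find('name').text, *coords))
    return (width, height), boxes


# Usage on a file (path is illustrative):
# size, boxes = read_voc_boxes(ET.parse('Annotations/000001.xml').getroot())
```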


VOC conversion process

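VOC stores labels in XML files, so the labels of each image must be written into an XML file. Much of the reference material I found contained only incomplete code, so I wrote my own XML generation code based on it.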

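To keep things simple, another class, WiderFaceSample, represents one sample of the dataset. Face detection on WIDER FACE has only two classes, face and background, so the object/name field in the XML is hard-coded to face.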
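For comparison, the same annotation structure can be built more compactly with the standard-library `xml.etree.ElementTree` instead of `xml.dom.minidom`. This is a sketch, not the author's code: `build_voc_xml` is a hypothetical name, and the field values mirror those the class below writes (folder, filename, size, segmented, and one `face` object per box with pose Left, truncated 1, difficult 0). Boxes are converted from WIDER FACE's (x, y, w, h) to VOC corners as xmax = x + w, ymax = y + h.

```python
import xml.etree.ElementTree as ET


def build_voc_xml(folder, filename, width, height, face_rects, depth=3):
    """Build a VOC-style <annotation> element for one image (hypothetical helper).

    face_rects holds WIDER FACE boxes as [x, y, w, h].
    """
    annotation = ET.Element('annotation')
    ET.SubElement(annotation, 'folder').text = folder
    ET.SubElement(annotation, 'filename').text = filename

    size = ET.SubElement(annotation, 'size')
    for tag, value in (('width', width), ('height', height), ('depth', depth)):
        ET.SubElement(size, tag).text = str(value)
    ET.SubElement(annotation, 'segmented').text = '0'

    for x, y, w, h in face_rects:
        obj = ET.SubElement(annotation, 'object')
        ET.SubElement(obj, 'name').text = 'face'  # single foreground class
        ET.SubElement(obj, 'pose').text = 'Left'
        ET.SubElement(obj, 'truncated').text = '1'
        ET.SubElement(obj, 'difficult').text = '0'
        bndbox = ET.SubElement(obj, 'bndbox')
        # (x, y, w, h) -> (xmin, ymin, xmax, ymax)
        for tag, value in (('xmin', x), ('ymin', y), ('xmax', x + w), ('ymax', y + h)):
            ET.SubElement(bndbox, tag).text = str(value)
    return annotation


# Writing it to disk (values and path are illustrative):
# tree = ET.ElementTree(build_voc_xml('WiderFace_VOC', '000001.jpg', 1024, 768, [[10, 20, 30, 40]]))
# tree.write('000001.xml', encoding='utf-8', xml_declaration=True)
```

ElementTree writes the file without indentation; `ET.indent` (available from Python 3.9) can pretty-print it if needed.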

import os
import shutil
from xml.dom.minidom import Document

import dataset  # the WiderFaceDataset script above, saved as dataset.py


class WiderFaceSample:

    def __init__(self):
        self.face_rects = []
        self.image_name = ''
        self.image = None
        self.image_width = 0
        self.image_height = 0

    def save_image(self, folder_path, new_name):
        shutil.copy(src=self.image_name, dst=os.path.join(folder_path, new_name))

    def save_label_to_txt(self, folder_path, new_name):
        with open(os.path.join(folder_path, new_name), 'w') as f:
            for rect in self.face_rects:
                f.write('{} {} {} {}\n'.format(rect[0], rect[1], rect[2], rect[3]))

    def save_label_to_xml(self, folder_path, new_xml_name, new_image_name):
        doc = self._generate_xml('WiderFace_VOC', new_image_name,self.face_rects)
        var = doc.toprettyxml(indent='\t', encoding='utf-8')
        with open(os.path.join(folder_path, new_xml_name), 'w') as f:
            f.write(var.decode())

    def _generate_xml(self, folder_str, filename_str, face_rects):

        # https://www.cnblogs.com/haigege/p/5712854.html
        # https://www.cnblogs.com/zjutzz/p/6847848.html
        # https://www.cnblogs.com/qw12/p/6185126.html

        doc = Document()

        # root
        annotation = doc.createElement('annotation')

        doc.appendChild(annotation)

        # -----------folder-----------------------
        folder = doc.createElement('folder')
        folder_text = doc.createTextNode(folder_str)
        folder.appendChild(folder_text)
        annotation.appendChild(folder)

        # ------------filename-------------------
        filename = doc.createElement('filename')
        filename_text = doc.createTextNode(filename_str)
        filename.appendChild(filename_text)
        annotation.appendChild(filename)

        # -------------size--------------------
        size = doc.createElement('size')
        width_text = doc.createTextNode(str(self.image_width))
        height_text = doc.createTextNode(str(self.image_height))
        depth_text = doc.createTextNode(str(3))
        width = doc.createElement('width')
        height = doc.createElement('height')
        depth = doc.createElement('depth')
        width.appendChild(width_text)
        height.appendChild(height_text)
        depth.appendChild(depth_text)
        size.appendChild(width)
        size.appendChild(height)
        size.appendChild(depth)
        annotation.appendChild(size)

        # ---------------segmented-------------
        segmented_text = doc.createTextNode(str(0))
        segmented = doc.createElement('segmented')
        segmented.appendChild(segmented_text)
        annotation.appendChild(segmented)

        # --------------object----------------
        for rect in face_rects:
            # 'obj' avoids shadowing the built-in name 'object'
            obj = doc.createElement('object')
            # The class name is fixed: face is the only foreground class
            name_text = doc.createTextNode('face')
            # pose/truncated/difficult: placeholder values copied from the VOC example above
            pose_text = doc.createTextNode('Left')
            truncated_text = doc.createTextNode(str(1))
            difficult_text = doc.createTextNode(str(0))
            name = doc.createElement('name')
            name.appendChild(name_text)
            pose = doc.createElement('pose')
            pose.appendChild(pose_text)
            truncated = doc.createElement('truncated')
            truncated.appendChild(truncated_text)
            difficult = doc.createElement('difficult')
            difficult.appendChild(difficult_text)
            obj.appendChild(name)
            obj.appendChild(pose)
            obj.appendChild(truncated)
            obj.appendChild(difficult)

            # WIDER FACE boxes are (x, y, w, h); VOC stores the corners (xmin, ymin, xmax, ymax)
            x_min, y_min, x_max, y_max = rect[0], rect[1], rect[0] + rect[2], rect[1] + rect[3]
            bndbox = doc.createElement('bndbox')
            xmin_text = doc.createTextNode(str(x_min))
            ymin_text = doc.createTextNode(str(y_min))
            xmax_text = doc.createTextNode(str(x_max))
            ymax_text = doc.createTextNode(str(y_max))
            xmin = doc.createElement('xmin')
            ymin = doc.createElement('ymin')
            xmax = doc.createElement('xmax')
            ymax = doc.createElement('ymax')
            xmin.appendChild(xmin_text)
            ymin.appendChild(ymin_text)
            xmax.appendChild(xmax_text)
            ymax.appendChild(ymax_text)
            bndbox.appendChild(xmin)
            bndbox.appendChild(ymin)
            bndbox.appendChild(xmax)
            bndbox.appendChild(ymax)

            obj.appendChild(bndbox)

            annotation.appendChild(obj)

        return doc


if __name__ == '__main__':
    # The original WIDER FACE dataset defined above
    original_dataset = dataset.WiderFaceDataset(
        images_folder='/media/weipenghui/Extra/WiderFace/WIDER_train/images',
        ground_truth_file='/media/weipenghui/Extra/WiderFace/WIDER_train/wider_face_train_bbx_gt.txt')

    voc_root = '/media/weipenghui/Extra/WiderFace/WiderFace_VOC'
    # Create the VOC directory tree; exist_ok=True lets the script be re-run safely
    folders = ['Annotations', 'ImageSets/Main', 'ImageSets/Layout', 'ImageSets/Segmentation',
               'JPEGImages', 'SegmentationClass', 'SegmentationObject']
    for folder in folders:
        os.makedirs(os.path.join(voc_root, folder), exist_ok=True)

    wfsample = WiderFaceSample()
    with open(os.path.join(voc_root, 'ImageSets', 'Main', 'trainval.txt'), 'w') as train_txt:
        # Samples are numbered from 1: 000001, 000002, ...
        for i, sample in enumerate(original_dataset, 1):
            new_name = str(i).zfill(6)
            wfsample.image_name = sample['image_name']
            wfsample.face_rects = sample['label']
            # PIL's Image.size is (width, height)
            wfsample.image_width, wfsample.image_height = sample['image'].size
            # Write the image
            wfsample.save_image(os.path.join(voc_root, 'JPEGImages'), new_name + '.jpg')
            # Write the XML label
            wfsample.save_label_to_xml(os.path.join(voc_root, 'Annotations'),
                                       new_xml_name=new_name + '.xml',
                                       new_image_name=new_name + '.jpg')
            # Write the image ID to the split list
            train_txt.write(new_name + '\n')
            print('Write: {}'.format(i))
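The script above writes every image ID to ImageSets/Main/trainval.txt. VOC itself also ships separate train.txt and val.txt lists, and some training scripts read those. One way to derive them is a seeded random split of the IDs, sketched below. The 10% validation ratio, the seed and the name `split_ids` are assumptions for illustration.

```python
import random


def split_ids(ids, val_ratio=0.1, seed=0):
    """Shuffle image IDs with a fixed seed and split them into (train, val) lists."""
    ids = list(ids)
    rng = random.Random(seed)  # local RNG so the split is reproducible
    rng.shuffle(ids)
    n_val = int(len(ids) * val_ratio)
    # Sort each part so the list files stay in ID order
    return sorted(ids[n_val:]), sorted(ids[:n_val])


# Usage (run inside ImageSets/Main; paths are illustrative):
# with open('trainval.txt') as f:
#     train, val = split_ids(line.strip() for line in f if line.strip())
# for name, part in (('train.txt', train), ('val.txt', val)):
#     with open(name, 'w') as f:
#         f.write(''.join(i + '\n' for i in part))
```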

Final result

  • Root directory
  • Annotations
  • ImageSets
  • JPEGImages

End

Reposted from blog.csdn.net/weixin_34116110/article/details/87130578