1.准备数据集
使用labelimg将数据集中需要识别的部位框出来
2.划分数据集(训练集和验证集)
编写代码,自动划分,以及将VOC格式转为YOLO格式
import xml.etree.ElementTree as ET
import os
import random
from shutil import copyfile
import cv2
# Class names to export. A class's position in this list is the class id
# written to the YOLO label files; objects with any other name are skipped.
classes = ["gas", "fire"]
# Fraction of the images assigned to the training split; the remainder
# goes to the validation split.
TRAIN_RATIO = 0.8
def clear_hidden_files(path):
    """Recursively delete files whose names start with "._" under ``path``.

    Such files are typically macOS resource-fork leftovers; if they stay in a
    dataset folder they get listed alongside the real images/annotations.

    Args:
        path: Directory to clean. Sub-directories are cleaned as well.

    Raises:
        OSError: If ``path`` cannot be listed or a matching file cannot be
            removed.
    """
    root = os.path.abspath(path)
    for name in os.listdir(root):
        abspath = os.path.join(root, name)
        if os.path.isfile(abspath):
            if name.startswith("._"):
                os.remove(abspath)
        elif os.path.isdir(abspath):
            # Recurse only into real directories. Entries that are neither a
            # regular file nor a directory (dangling symlinks, sockets, ...)
            # are skipped instead of crashing os.listdir().
            clear_hidden_files(abspath)
def convert(size, box):
    """Convert a VOC pixel box into a normalized YOLO box.

    Args:
        size: ``(image_width, image_height)`` in pixels.
        box: ``(xmin, xmax, ymin, ymax)`` in pixels. Note the order: both x
            bounds come first, then both y bounds.

    Returns:
        ``(x_center, y_center, width, height)``, each divided by the matching
        image dimension.

    Raises:
        ZeroDivisionError: If either image dimension is 0.
    """
    # The reciprocal-then-multiply form is kept on purpose so the generated
    # values stay identical to labels produced by earlier runs.
    dw = 1. / size[0]
    dh = 1. / size[1]

    x_center = (box[0] + box[1]) / 2.0
    y_center = (box[2] + box[3]) / 2.0
    box_w = box[1] - box[0]
    box_h = box[3] - box[2]

    return (x_center * dw, y_center * dh, box_w * dw, box_h * dh)
def convert_annotation(label_id):
    """Convert one VOC XML annotation into a YOLO label file.

    Reads ``VOCdevkit/VOC2007/Annotations/<label_id>.xml`` and writes
    ``VOCdevkit/VOC2007/YOLOLabels/<label_id>.txt`` (both relative to the
    current working directory). Each kept object becomes one line:
    ``<class_id> <x_center> <y_center> <width> <height>``.

    Objects whose name is not in ``classes`` or that are marked
    ``difficult == 1`` are skipped. A missing or empty ``<difficult>`` tag is
    treated as "not difficult".

    Args:
        label_id: File name (without extension) shared by the XML annotation
            and the label file to produce.

    Raises:
        OSError: If the annotation cannot be read or the label cannot be
            written.
        xml.etree.ElementTree.ParseError: If the annotation is not valid XML.
    """
    xml_path = 'VOCdevkit/VOC2007/Annotations/%s.xml' % label_id
    txt_path = 'VOCdevkit/VOC2007/YOLOLabels/%s.txt' % label_id

    # Hand the path to ElementTree so the parser reads raw bytes and honors
    # the XML encoding declaration, instead of decoding with the platform's
    # default text encoding first.
    root = ET.parse(xml_path).getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    lines = []
    for obj in root.iter('object'):
        cls = obj.find('name').text
        difficult = obj.findtext('difficult') or '0'
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        # convert() expects (xmin, xmax, ymin, ymax), not the usual
        # (xmin, ymin, xmax, ymax) order.
        b = (
            float(xmlbox.find('xmin').text),
            float(xmlbox.find('xmax').text),
            float(xmlbox.find('ymin').text),
            float(xmlbox.find('ymax').text),
        )
        bb = convert((w, h), b)
        lines.append(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')

    # Write only after the whole annotation parsed successfully, so a broken
    # XML file does not leave an empty/partial label behind; the context
    # manager closes the file even if the write fails.
    with open(txt_path, 'w') as out_file:
        out_file.writelines(lines)
def _prepare_dir(path):
    """Create ``path`` if it is missing, purge "._*" files under it, return it."""
    os.makedirs(path, exist_ok=True)
    clear_hidden_files(path)
    return path


def vocToTxt():
    """Convert the VOC dataset to YOLO format and split it into train/val.

    Works under ``<cwd>/VOCdevkit``:

    * reads images from ``VOC2007/JPEGImages`` and XML annotations from
      ``VOC2007/Annotations``;
    * writes converted labels to ``VOC2007/YOLOLabels``;
    * copies every image/label pair into ``images/{train,val}`` and
      ``labels/{train,val}``.

    Missing directories are created. Images are assigned a random rank and
    the lowest ``TRAIN_RATIO`` share of ranks goes to the training split, so
    the split sizes are exact rather than approximate. Images that have no
    annotation named after them are skipped.

    :return: None
    """
    data_base_dir = os.path.join(os.getcwd(), "VOCdevkit/")
    work_space_dir = os.path.join(data_base_dir, "VOC2007/")
    os.makedirs(work_space_dir, exist_ok=True)

    annotation_dir = _prepare_dir(os.path.join(work_space_dir, "Annotations/"))
    image_dir = _prepare_dir(os.path.join(work_space_dir, "JPEGImages/"))
    yolo_labels_dir = _prepare_dir(os.path.join(work_space_dir, "YOLOLabels/"))
    yolov5_images_dir = _prepare_dir(os.path.join(data_base_dir, "images/"))
    yolov5_labels_dir = _prepare_dir(os.path.join(data_base_dir, "labels/"))
    yolov5_images_train_dir = _prepare_dir(os.path.join(yolov5_images_dir, "train/"))
    yolov5_images_test_dir = _prepare_dir(os.path.join(yolov5_images_dir, "val/"))
    yolov5_labels_train_dir = _prepare_dir(os.path.join(yolov5_labels_dir, "train/"))
    yolov5_labels_test_dir = _prepare_dir(os.path.join(yolov5_labels_dir, "val/"))

    list_imgs = os.listdir(image_dir)  # list image files
    total = len(list_imgs)
    # A random permutation of 0..total-1: each image gets a unique rank.
    prob = random.sample(range(total), total)
    train_cutoff = total * TRAIN_RATIO

    for rank, img_name in zip(prob, list_imgs):
        image_path = os.path.join(image_dir, img_name)
        name_without_extension, _ = os.path.splitext(img_name)
        # Look the annotation up by the image's own name. Pairing the two
        # directory listings by index is unreliable: listing order is
        # arbitrary and the two folders may hold different numbers of files.
        annotation_path = os.path.join(annotation_dir, name_without_extension + ".xml")
        label_name = name_without_extension + ".txt"
        label_path = os.path.join(yolo_labels_dir, label_name)
        print("Pro: %d" % rank)

        if not os.path.exists(annotation_path):
            continue  # unlabeled image: nothing to convert or copy

        if rank < train_cutoff:  # train dataset
            images_out, labels_out = yolov5_images_train_dir, yolov5_labels_train_dir
        else:  # validation dataset
            images_out, labels_out = yolov5_images_test_dir, yolov5_labels_test_dir

        convert_annotation(name_without_extension)  # convert label
        copyfile(image_path, os.path.join(images_out, img_name))
        copyfile(label_path, os.path.join(labels_out, label_name))
3.训练模型
首先根据voc.yaml,制作一个自己数据集的yaml文件,把path改为自己的数据集路径,nc改为自己数据集的类别数,names改为自己的类别名称
修改train.py文件(这里以yolov5s为例,也可以使用其他模型)
weights:default改为yolov5s.pt(如果不想使用预训练权重的话,可以输入none)
cfg:default改为yolov5s.yaml(选择模型文件)
data:data为刚才数据集yaml文件的储存路径
epochs:训练的轮次,自己设置
batch_size:根据自己的显存选择,16, 32, 64都可以,如果报错(显存不足)的话就改小一点
接下来运行train.py文件就可以了