随机生成训练集和测试集
代码需在图片目录的上一层目录中运行，并假设图片目录中的每张图片都有对应的 xml 或 txt 格式的标记文件。
# -*- coding: UTF-8 -*-
# import xml.etree.ElementTree as ET
import os
import random
def _to_forward_slashes(path):
    """Return ``path`` with every backslash replaced by a forward slash."""
    return path.replace('\\', '/')


def _collect_jpg_paths(wd):
    """Collect the paths of all ``.jpg`` / ``.JPG`` files one level below ``wd``.

    Only the immediate sub-directories of ``wd`` are scanned (no recursion),
    and only regular files are considered.  The extension is taken with
    ``os.path.splitext`` so that names containing several dots
    (``a.b.jpg``) are matched and names such as ``a.jpg.txt`` are not.

    Returns a list of ``<sub-directory path>/<file name>`` strings in
    ``os.listdir`` order.
    """
    pictures = []
    for entry in os.listdir(wd):
        dir_path = os.path.join(wd, entry)
        if not os.path.isdir(dir_path):
            continue
        for name in os.listdir(dir_path):
            if not os.path.isfile(os.path.join(dir_path, name)):
                continue
            if os.path.splitext(name)[1] in ('.jpg', '.JPG'):
                pictures.append(dir_path + '/' + name)
    return pictures


def _split_train_test(pictures, train_path, test_path):
    """Randomly distribute ``pictures`` between two list files.

    Each picture path is written (with forward slashes, one per line) to
    ``train_path`` when ``random.randint(1, 100) < 75`` and to ``test_path``
    otherwise.  The comparison is strict, so 74 of the 100 possible draws
    select the training file.  Both files are always created, even when
    ``pictures`` is empty.
    """
    with open(train_path, 'w') as train_file, open(test_path, 'w') as test_file:
        for picture in pictures:
            line = _to_forward_slashes(picture) + '\n'
            if random.randint(1, 100) < 75:
                train_file.write(line)
            else:
                test_file.write(line)


def _write_data_file(wd):
    """Write ``cfg/taco.data`` under ``wd``.

    The class count is the number of lines in ``cfg/taco.names``; the other
    entries are absolute, forward-slash paths rooted at ``wd``.

    Raises whatever ``open`` raises (e.g. when ``cfg/taco.names`` is
    missing); the error is intentionally not swallowed.
    """
    root = _to_forward_slashes(wd)
    names_path = root + '/cfg/taco.names'
    # Use a context manager so the names file is closed after counting.
    with open(names_path, 'r') as names_file:
        count = len(names_file.readlines())
    with open(os.path.join(wd, 'cfg', 'taco.data'), 'w') as f:
        f.write('classes=' + str(count) + '\n')
        f.write('train = ' + root + '/train_file.txt' + '\n')
        f.write('valid = ' + root + '/test_file.txt' + '\n')
        f.write('names = ' + root + '/cfg/taco.names' + '\n')
        f.write('backup = ' + root + '/backup' + '\n')


def main(wd=None):
    """Generate random train/test image lists and the ``taco.data`` file.

    ``wd`` is the directory that contains the image sub-directories and the
    ``cfg`` directory; it defaults to the current working directory, which
    is what the script used unconditionally before.

    Writes ``train_file.txt``, ``test_file.txt`` and ``cfg/taco.data`` under
    ``wd`` and prints the paths of the data file, the network config, the
    pre-trained weights and the backup directory (presumably the arguments
    needed for a subsequent training command -- confirm against the caller).
    """
    if wd is None:
        wd = os.getcwd()

    pictures = _collect_jpg_paths(wd)

    # Kept for output compatibility: this draw is only printed and does not
    # influence how the pictures are split below.
    probe = random.randint(1, 100)
    print("Probability: %d" % probe)

    _split_train_test(
        pictures,
        os.path.join(wd, 'train_file.txt'),
        os.path.join(wd, 'test_file.txt'),
    )
    _write_data_file(wd)

    root = _to_forward_slashes(wd)
    print(root + '/cfg/taco.data')
    print(root + '/cfg/yolov4-tiny.cfg')
    print(root + '/pre_trained/yolov4-tiny.conv.29')
    print(root + '/backup')


if __name__ == "__main__":
    main()