从图片文件夹随机生成训练集和测试集

随机生成训练集和测试集
代码在图片目录上一层目录运行,设图片目录中每张图片都有xml或txt格式的标记文件。

# -*- coding: UTF-8 -*-
# import xml.etree.ElementTree as ET
import os
import random


if __name__ == "__main__":
    wd = os.getcwd()  # 获取当前文件目录
    dir_list = os.listdir(wd)
    dir_pic = []
    for item in dir_list:
        if os.path.isdir(item):
            # dir.append(item)
            dir_path = os.path.join(wd, item)
            dir_list_temp = os.listdir(dir_path)
            for item2 in dir_list_temp:
                if os.path.isfile(dir_path.replace('\\', '/') + '/' + item2):
                    try:
                        # print(item2.split('.')[0])
                        if item2.split('.')[1] == 'jpg' or item2.split('.')[1] == 'JPG':
                            dir_pic.append(dir_path + '/' + item2)  # 此处搜集完所有xml文件路径名称放在dir中
                    except IndexError:
                        pass
                        # print('IndexError, 可能遇到了不带后缀名的文件')
    # ************************************************************
    # 生成随机数用于随机分配训练集和测试集
    probe = random.randint(1, 100)

    print("Probability: %d" % probe)

    train_file = open('train_file.txt', 'w')
    test_file = open('test_file.txt', 'w')

    for item in dir_pic:
        item = item.replace('\\', '/')
        probe = random.randint(1, 100)
        if probe < 75:
            train_file.write(item + '\n')
        else:
            test_file.write(item + '\n')

    train_file.close()
    test_file.close()
    
    nameFile_dir = wd.replace('\\', '/') + '/cfg/taco.names'
    count = len(open(nameFile_dir , 'r').readlines())

    with open('cfg/taco.data', 'w') as f:
        f.write('classes=' + str(count) + '\n')
        f.write('train = ' + wd.replace('\\', '/') + '/train_file.txt' + '\n')
        f.write('valid = ' + wd.replace('\\', '/') + '/test_file.txt' + '\n')
        f.write('names = ' + wd.replace('\\', '/') + '/cfg/taco.names' + '\n')
        f.write('backup = ' + wd.replace('\\', '/') + '/backup'+'\n')

    print(wd.replace('\\', '/')  + '/cfg/taco.data')
    print(wd.replace('\\', '/') + '/cfg/yolov4-tiny.cfg')
    print(wd.replace('\\', '/') + '/pre_trained/yolov4-tiny.conv.29')
    print(wd.replace('\\', '/') + '/backup')


猜你喜欢

转载自blog.csdn.net/ohhardtoname/article/details/115299636