Python: 自动将Faster数据集随机划分为trainval、test、val、train

2018.10.20

trainval_num与train_num为比例。

set(a).difference(set(b)) 求差集

random.shuffle()原地打乱数据,无返回值(返回None)。

math.floor向下取整。

random.sample(a,b):在a中随机采样b个元素。

os.listdir()生成文件夹下文件名称列表。

"""
https://blog.csdn.net/gusui7202/article/details/83239142
qhy。
"""
# -*- coding: utf -*-
import os
import random
import math
# Default locations (VOC2007 layout) and split ratios used when run as a script.
filepath = 'H:/qhy_database/Dataset_009/VOC2007/Annotations'
imagesets_dir = 'H:/qhy_database/Dataset_009/VOC2007/ImageSets/Main'
# Fraction of all samples that goes into trainval (the rest becomes test).
trainval_num = 0.8
# Fraction of trainval that goes into train (the rest becomes val).
train_num = 0.8


def list_annotation_names(annotations_dir):
    """Return the sorted base names of the ``.xml`` files in *annotations_dir*.

    Only the trailing extension is removed, so ``a.xml.xml`` yields ``a.xml``.
    Entries whose name does not end in ``.xml`` (case-insensitive) are skipped.
    """
    return sorted(
        os.path.splitext(entry)[0]
        for entry in os.listdir(annotations_dir)
        if entry.lower().endswith('.xml')
    )


def split_dataset(names, trainval_ratio=0.8, train_ratio=0.8, seed=None):
    """Randomly partition *names* into trainval/test and train/val subsets.

    Args:
        names: sequence of sample names.
        trainval_ratio: fraction of all samples placed in ``trainval``
            (rounded down); the remainder forms ``test``.
        train_ratio: fraction of ``trainval`` placed in ``train``
            (rounded down); the remainder forms ``val``.
        seed: optional seed for a reproducible split; ``None`` gives a
            different split on every call.

    Returns:
        A dict with the keys ``'trainval'``, ``'test'``, ``'train'`` and
        ``'val'``, each mapping to a list of names in random order.

    Raises:
        ValueError: if either ratio lies outside ``[0, 1]``.
    """
    for label, ratio in (('trainval_ratio', trainval_ratio),
                         ('train_ratio', train_ratio)):
        if not 0 <= ratio <= 1:
            raise ValueError(
                '%s must be between 0 and 1, got %r' % (label, ratio))

    names = list(names)
    rng = random.Random(seed)
    indices = range(len(names))

    # trainval: random sample of all indices; test: whatever is left over.
    trainval_idx = rng.sample(indices, math.floor(trainval_ratio * len(names)))
    test_idx = list(set(indices).difference(trainval_idx))
    # The set difference comes back in set-iteration order, so shuffle it.
    rng.shuffle(test_idx)

    # train: random sample of trainval; val: the rest of trainval.
    train_idx = rng.sample(
        trainval_idx, math.floor(train_ratio * len(trainval_idx)))
    val_idx = list(set(trainval_idx).difference(train_idx))
    rng.shuffle(val_idx)

    return {
        'trainval': [names[i] for i in trainval_idx],
        'test': [names[i] for i in test_idx],
        'train': [names[i] for i in train_idx],
        'val': [names[i] for i in val_idx],
    }


def write_split(path, names):
    """Write *names* to *path*, one name per line, replacing any old file."""
    # The context manager closes the file even if a write fails.
    with open(path, 'w') as handle:
        handle.writelines(name + '\n' for name in names)


def main(annotations_dir=filepath, output_dir=imagesets_dir,
         trainval_ratio=trainval_num, train_ratio=train_num, seed=None):
    """Split the annotations in *annotations_dir* and write the list files.

    Writes ``trainval.txt``, ``test.txt``, ``train.txt`` and ``val.txt`` into
    *output_dir* (created if missing) and returns the dict produced by
    ``split_dataset``.
    """
    # Read and split first, so a missing or unreadable annotation directory
    # fails before any existing split file is truncated.
    names = list_annotation_names(annotations_dir)
    splits = split_dataset(names, trainval_ratio, train_ratio, seed)

    os.makedirs(output_dir, exist_ok=True)
    for split_name, members in splits.items():
        write_split(os.path.join(output_dir, split_name + '.txt'), members)
    return splits


if __name__ == '__main__':
    main()

猜你喜欢

转载自blog.csdn.net/gusui7202/article/details/83215632