Creating and Reading a TFRecords Dataset (Using the NWPU VHR-10 Dataset as an Example)
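TFRecords is a data format supported by the TensorFlow deep learning framework. It packages data well, is easy to read and easy to move around, and TensorFlow provides many built-in functions for working with it (for example, reading it in batches), which makes it convenient to use. This post therefore shows how to turn training data into TFRecords and how to read TFRecords data back, using NWPU VHR-10 as the example.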
Dataset preprocessing:
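The structure of the NWPU VHR-10 dataset is shown in Figure 1. The ground truth folder stores the bounding boxes and class labels of the objects in the positive image set. The bounding-box label format is shown in Figure 2: on each line, the two pairs of parentheses hold the coordinates of the object's top-left and bottom-right corners, followed by its class label.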
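To make the ground-truth format concrete, here is a minimal parsing sketch. It assumes each line has the form (x1,y1),(x2,y2),class_id with integer values, which is consistent with Figure 2 and with the preprocessing code below (which reads the class from the fifth comma-separated field). The helper name parse_gt_line and the sample line are hypothetical.

```python
import re

# Hypothetical helper: parse one ground-truth line of the assumed form
# "(x1,y1),(x2,y2),class_id" (integer coordinates, optional spaces).
GT_PATTERN = re.compile(
    r'\(\s*(\d+)\s*,\s*(\d+)\s*\)\s*,\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)\s*,\s*(\d+)')


def parse_gt_line(line):
    """Return (xmin, ymin, xmax, ymax, class_id), or None for blank/malformed lines."""
    m = GT_PATTERN.search(line)
    if m is None:
        return None
    return tuple(int(g) for g in m.groups())


print(parse_gt_line('(563,478),(630,573),1'))  # (563, 478, 630, 573, 1)
print(parse_gt_line(''))                       # None
```

A regular expression is used here only because it rejects malformed lines explicitly; the string-replace approach in the original code works too.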
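Now that we know how the dataset is organized, we resize all training images to a common size and convert the bounding-box labels into the format shown in Figure 3, which makes building the dataset easier later on. (Note: the bounding-box coordinates must be scaled along with the image!)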
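The scaling rule itself is simple: x-coordinates are multiplied by new_width / old_width and y-coordinates by new_height / old_height. A minimal sketch follows; the 800x600 target size is an assumption taken from the TFRecord code later in this post, and scale_box plus the example image size are hypothetical. Note that for a NumPy image, width is shape[1] and height is shape[0].

```python
def scale_box(box, old_size, new_size=(800, 600)):
    """Scale (xmin, ymin, xmax, ymax) from old_size to new_size.

    Both sizes are (width, height). Multiplying before dividing mirrors the
    original convert() function and keeps the arithmetic exact for integers.
    """
    old_w, old_h = old_size
    new_w, new_h = new_size
    xmin, ymin, xmax, ymax = box
    return (int(new_w * xmin / old_w), int(new_h * ymin / old_h),
            int(new_w * xmax / old_w), int(new_h * ymax / old_h))


# A 1000x500 (width x height) image resized to 800x600:
print(scale_box((100, 50, 500, 250), (1000, 500)))  # (80, 60, 400, 300)
```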
Dataset preprocessing code:
import os
import numpy as np
from skimage import io
import utils   # project-local helper module (provides resize_image; not shown in this post)
import config  # project-local config module (provides the target size img_w, img_h)

ROOT = './NWPU VHR-10 dataset'


def convert(size, box):
    """Scale a box (xmin, ymin, xmax, ymax) from the original image size
    size = (width, height) to the target size (config.img_w, config.img_h)."""
    xmin = config.img_w * box[0] / size[0]
    ymin = config.img_h * box[1] / size[1]
    xmax = config.img_w * box[2] / size[0]
    ymax = config.img_h * box[3] / size[1]
    return (int(xmin), int(ymin), int(xmax), int(ymax))


def convert_annotation(image_id):
    # Resize the image to the common size and save it into traindata/
    img = io.imread('%s/positive image set/%s.jpg' % (ROOT, image_id))
    out_image = utils.resize_image(img, config.img_w, config.img_h)
    io.imsave('%s/traindata/%s.jpg' % (ROOT, image_id), out_image)
    # A NumPy image has shape (height, width, channels), so height comes first
    h, w = img.shape[:2]
    with open('%s/ground truth/%s.txt' % (ROOT, image_id)) as f:
        lines = f.read().splitlines()
    with open('%s/labels/%s.txt' % (ROOT, image_id), 'w') as out_file:
        for obj in lines:
            # Each line looks like "(x1,y1),(x2,y2),class_id"
            fields = obj.replace('(', ' ').replace(')', ' ').strip().split(',')
            if len(fields) > 1:
                cls_id = int(fields[4])
                b = tuple(float(v) for v in fields[:4])
                bb = convert((w, h), b)
                # Output line: "xmin ymin xmax ymax class_id"
                out_file.write('%d %d %d %d %d\n' % (bb + (cls_id,)))
            else:
                print(obj)  # report blank or malformed lines


os.makedirs(ROOT + '/labels', exist_ok=True)
os.makedirs(ROOT + '/traindata', exist_ok=True)
# The positive image set holds 650 images named 001.jpg ... 650.jpg
image_ids = [str(i).zfill(3) for i in range(1, 651)]
with open(ROOT + '/train.txt', 'w') as list_file:
    for image_id in image_ids:
        list_file.write('%s/positive image set/%s.jpg\n' % (ROOT, image_id))
        convert_annotation(image_id)
Creating the TFRecords file:
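After the previous step we have a set of training images of uniform size, plus a label file for each image (each line gives one object's coordinates and class). Next we build the TFRecords file.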
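As I understand it, TFRecords stores data in dictionary form: each training sample is a tf.train.Example whose features map a name to a value, as shown below: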
feature = tf.train.Features(feature={
    'img_height': _int64_feature(img_height),
    'img_width': _int64_feature(img_width),
    'img': _bytes_feature(img.tobytes()),
    'gtboxes_and_label': _bytes_feature(gtbox_label.tobytes()),
    'num_objects': _int64_feature(gtbox_label.shape[0])
})
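A TFRecords feature can hold three kinds of values: int64, float, and bytes (Int64List, FloatList and BytesList). We therefore convert the image array and the ground-truth label array to byte strings and store them as bytes features using TensorFlow's built-in functions, and store the image height, width and number of objects as int64 features. The result is a TFRecords file. The complete code is as follows: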
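The key point is that raw bytes carry no shape or dtype information. That is why the height, width and number of objects are stored as separate int64 features, and why the reader must decode with the same dtype the writer used (uint8 for the image, int32 for the boxes). The NumPy-only sketch below mimics that round trip; np.frombuffer plays the role of tf.decode_raw here, and the array sizes and box values are illustrative.

```python
import numpy as np

# Writer side: an image and its boxes become plain byte strings.
img = np.random.randint(0, 256, size=(600, 800, 3), dtype=np.uint8)
boxes = np.array([[80, 60, 400, 300, 1],
                  [10, 20, 30, 40, 7]], dtype=np.int32)
img_bytes, box_bytes = img.tobytes(), boxes.tobytes()
height, width, num_objects = img.shape[0], img.shape[1], boxes.shape[0]

# Reader side: bytes -> flat array (same dtype as written) -> reshape.
img_back = np.frombuffer(img_bytes, dtype=np.uint8).reshape(height, width, 3)
boxes_back = np.frombuffer(box_bytes, dtype=np.int32).reshape(-1, 5)

print(len(img_bytes), len(box_bytes))  # 1440000 40
print(np.array_equal(img, img_back), np.array_equal(boxes, boxes_back))  # True True
# Decoding with the wrong dtype silently yields 40 values instead of 10:
print(np.frombuffer(box_bytes, dtype=np.uint8).size)  # 40
```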
import os
import glob
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.python_io)
import cv2


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def read_gtbox_and_label(imname):
    """Read labels/<imname>.txt into an int32 array of shape [num_objects, 5]
    with columns (xmin, ymin, xmax, ymax, class_id)."""
    with open('./NWPU VHR-10 dataset/labels/' + imname + '.txt') as f:
        boxdata = [[int(v) for v in line.split()] for line in f if line.strip()]
    return np.array(boxdata, dtype=np.int32).reshape([-1, 5])


def convert_data_to_tfrecord():
    save_path = 'TFrecord.tfrecords'
    data_path = './NWPU VHR-10 dataset/traindata'
    writer = tf.python_io.TFRecordWriter(path=save_path)
    for picture in glob.glob(os.path.join(data_path, '*.jpg')):
        # Portable way to get "001" from ".../traindata/001.jpg" (works on Windows and Linux)
        img_name = os.path.splitext(os.path.basename(picture))[0]
        img = cv2.imread(picture)  # uint8 array of shape (H, W, 3), BGR channel order
        img_height, img_width = img.shape[:2]  # 600 and 800 after preprocessing
        gtbox_label = read_gtbox_and_label(img_name)
        feature = tf.train.Features(feature={
            'img_height': _int64_feature(img_height),
            'img_width': _int64_feature(img_width),
            'img': _bytes_feature(img.tobytes()),
            'gtboxes_and_label': _bytes_feature(gtbox_label.tobytes()),
            'num_objects': _int64_feature(gtbox_label.shape[0])
        })
        example = tf.train.Example(features=feature)
        writer.write(example.SerializeToString())
    writer.close()  # flush buffered records and close the file
    print('End!')


if __name__ == '__main__':
    convert_data_to_tfrecord()
Reading the TFRecords data:
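After the steps above, all of our training data is stored in the single file TFrecord.tfrecords; one file holds every training image together with its labels, which shows how well the format packages data. The next step is reading the data back. In the code below, TensorFlow's built-in functions create a queue of training files and decode the stored examples, and tf.train.batch then produces the training data batch by batch. Running the batch tensors with sess.run gives us the training data, which can be fed straight into the network.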
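Because each image contains a different number of objects, the box tensors cannot simply be stacked into one batch; this is why dynamic_pad=True is passed to tf.train.batch, which pads each tensor in a batch to the largest shape in that batch (with zeros for numeric types). The NumPy sketch below imitates that padding for the box arrays; it illustrates the idea and is not TensorFlow's implementation, and pad_boxes is a hypothetical name.

```python
import numpy as np


def pad_boxes(box_list):
    """Stack [n_i, 5] box arrays into one [batch, max_n, 5] array, padding with zeros.

    Returns the padded array and the true object count of each image.
    """
    max_n = max(b.shape[0] for b in box_list)
    batch = np.zeros((len(box_list), max_n, 5), dtype=np.int32)
    for i, b in enumerate(box_list):
        batch[i, :b.shape[0]] = b
    counts = np.array([b.shape[0] for b in box_list], dtype=np.int32)
    return batch, counts


boxes_a = np.array([[80, 60, 400, 300, 1]], dtype=np.int32)
boxes_b = np.array([[10, 20, 30, 40, 7], [50, 60, 70, 80, 3]], dtype=np.int32)
batch, counts = pad_boxes([boxes_a, boxes_b])
print(batch.shape)  # (2, 2, 5)
print(batch[0, 1])  # [0 0 0 0 0]  <- padding row
print(counts)       # [1 2]
real_boxes_0 = batch[0, :counts[0]]  # recover image 0's real boxes
```

This is also why num_objects is worth storing: after padding, it is the simplest way to tell real boxes apart from padding rows.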
import tensorflow as tf  # TensorFlow 1.x API (queue-based input pipeline)
import cv2


def read_and_decode(filename):
    # Create a queue from the list of TFRecord file names
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)  # returns (key, serialized record)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'img_height': tf.FixedLenFeature([], tf.int64),
            'img_width': tf.FixedLenFeature([], tf.int64),
            'img': tf.FixedLenFeature([], tf.string),
            'gtboxes_and_label': tf.FixedLenFeature([], tf.string),
            'num_objects': tf.FixedLenFeature([], tf.int64)
        })
    img_height = tf.cast(features['img_height'], tf.int32)
    img_width = tf.cast(features['img_width'], tf.int32)
    # Decode the raw bytes back into a uint8 (H, W, 3) image (BGR, as written by cv2).
    # No resize is needed: all images were already resized during preprocessing.
    img = tf.decode_raw(features['img'], tf.uint8)
    img = tf.reshape(img, shape=[img_height, img_width, 3])
    # Boxes were written as int32, so decode them as int32: [num_objects, 5]
    gtboxes_and_label = tf.decode_raw(features['gtboxes_and_label'], tf.int32)
    gtboxes_and_label = tf.reshape(gtboxes_and_label, [-1, 5])
    num_objects = tf.cast(features['num_objects'], tf.int32)
    return img, gtboxes_and_label, num_objects


def data_batch(filepath, batch_size=8):
    img, label, num_objects = read_and_decode(filepath)
    # dynamic_pad=True zero-pads the box tensors, whose first dimension differs per image
    return tf.train.batch([img, label, num_objects], batch_size=batch_size,
                          capacity=2000, dynamic_pad=True)


if __name__ == '__main__':
    filepath = 'TFrecord.tfrecords'
    img_batch, label_batch, num_batch = data_batch(filepath)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        img, label, num = sess.run([img_batch, label_batch, num_batch])
        cv2.imwrite('./haha.jpg', img[1])  # save the 2nd image of the batch
        # Stop the queue-runner threads cleanly before the session closes
        coord.request_stop()
        coord.join(threads)