Copyright notice: this is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/u011956147/article/details/79290163
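TFRecords is TensorFlow's standard built-in data format, similar to Caffe's LMDB. Its advantage is that once the data has been prepared it is convenient and fast to use; the cost is extra disk space. In many cases, reading and processing the raw data directly is also a good approach, for example in detection tasks; I will add notes on multi-threaded direct data reading later. This post records how to use TFRecords and is intended only as a personal memo.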
1. Generating TFRecords
Use tf.train.Example to define the format of the data to be written to the protobuf buffer, and use tf.python_io.TFRecordWriter (TensorFlow 1.x API) to write it. For more on tf.train.Example, see: https://www.tensorflow.org/api_docs/python/tf/train/Example
import tensorflow as tf

# _bytes_feature, _int64_feature and _float_feature are helper functions that
# wrap a value in a tf.train.Feature; their definitions are not shown in this post.

def _convert_to_example_simple(image_example, image_buffer):
    """
    Convert one image example to a tf.train.Example proto.

    :param image_example: dict with an integer 'label' and a 'bbox' dict that
        holds the box corners and the five landmark points
    :param image_buffer: bytes, the image data (raw uint8 pixels, because
        section 2 decodes this field with tf.decode_raw)
    :return: Example proto
    """
    class_label = image_example['label']
    bbox = image_example['bbox']
    # Bounding box as [xmin, ymin, xmax, ymax].
    roi = [bbox['xmin'], bbox['ymin'], bbox['xmax'], bbox['ymax']]
    # Five landmark points as ten (x, y) coordinates.
    landmark = [bbox['xlefteye'], bbox['ylefteye'], bbox['xrighteye'], bbox['yrighteye'],
                bbox['xnose'], bbox['ynose'],
                bbox['xleftmouth'], bbox['yleftmouth'], bbox['xrightmouth'], bbox['yrightmouth']]

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': _bytes_feature(image_buffer),
        'image/label': _int64_feature(class_label),
        'image/roi': _float_feature(roi),
        'image/landmark': _float_feature(landmark)
    }))
    return example
with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
    example = _convert_to_example_simple(image_example, image_data)
    tfrecord_writer.write(example.SerializeToString())
Here image_example and image_data both need to be prepared in advance.
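As an illustration of what "prepared in advance" could mean, here is a minimal sketch of the two inputs. All values are hypothetical placeholders, and the image is assumed to be stored as raw uint8 pixel bytes, since section 2 decodes it with tf.decode_raw rather than a JPEG decoder. In practice the array would come from an image that was loaded and resized elsewhere.

```python
import numpy as np

image_size = 12  # hypothetical; must match the size used when reading back

# Dummy H x W x 3 uint8 image standing in for a real, already resized crop.
image = np.zeros((image_size, image_size, 3), dtype=np.uint8)
image_data = image.tobytes()  # raw pixel bytes, not a JPEG/PNG file

# Dict layout expected by _convert_to_example_simple; values are made up.
image_example = {
    'label': 1,
    'bbox': {
        'xmin': 0.0, 'ymin': 0.0, 'xmax': 1.0, 'ymax': 1.0,
        'xlefteye': 0.3, 'ylefteye': 0.3,
        'xrighteye': 0.7, 'yrighteye': 0.3,
        'xnose': 0.5, 'ynose': 0.5,
        'xleftmouth': 0.35, 'yleftmouth': 0.75,
        'xrightmouth': 0.65, 'yrightmouth': 0.75,
    },
}
```

The byte string has exactly image_size * image_size * 3 entries, which is what the reshape in the parsing step relies on.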
2. Parsing the data
def read_tfrecord(tfrecord_file, batch_size):
    # Queue of input file names; shuffle=True shuffles the order of the files.
    filename_queue = tf.train.string_input_producer([tfrecord_file], shuffle=True)
    # read tfrecord
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # The keys, shapes and dtypes must match what was written in section 1.
    image_features = tf.parse_single_example(
        serialized_example,
        features={
            'image/encoded': tf.FixedLenFeature([], tf.string),
            'image/label': tf.FixedLenFeature([], tf.int64),
            'image/roi': tf.FixedLenFeature([4], tf.float32),
            'image/landmark': tf.FixedLenFeature([10], tf.float32)
        }
    )
    # Decode the raw bytes; image_size must be defined elsewhere (e.g. at module level).
    image = tf.decode_raw(image_features['image/encoded'], tf.uint8)
    image = tf.reshape(image, [image_size, image_size, 3])
    # Some preprocessing: scale pixel values to roughly [-1, 1].
    image = (tf.cast(image, tf.float32) - 127.5) / 128
    label = tf.cast(image_features['image/label'], tf.float32)
    roi = tf.cast(image_features['image/roi'], tf.float32)
    landmark = tf.cast(image_features['image/landmark'], tf.float32)
    # Group single examples into batches using two reader threads.
    image, label, roi, landmark = tf.train.batch(
        [image, label, roi, landmark],
        batch_size=batch_size,
        num_threads=2,
        capacity=1 * batch_size
    )
    label = tf.reshape(label, [batch_size])
    roi = tf.reshape(roi, [batch_size, 4])
    landmark = tf.reshape(landmark, [batch_size, 10])
    return image, label, roi, landmark
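To make the decode and preprocessing steps concrete, the following NumPy sketch mirrors them outside TensorFlow: np.frombuffer stands in for tf.decode_raw, and the same (x - 127.5) / 128 scaling is applied. The image_size value and the synthetic pixel data are assumptions for illustration only.

```python
import numpy as np

image_size = 12  # hypothetical; stands in for the image_size used by read_tfrecord

# Synthetic image that cycles through every uint8 value 0..255.
n_values = image_size * image_size * 3
original = (np.arange(n_values) % 256).astype(np.uint8)
original = original.reshape(image_size, image_size, 3)
raw = original.tobytes()  # what would be stored under 'image/encoded'

# Counterpart of tf.decode_raw(..., tf.uint8): a flat uint8 vector.
decoded = np.frombuffer(raw, dtype=np.uint8)
# Counterpart of tf.reshape: restore height x width x channels.
decoded = decoded.reshape(image_size, image_size, 3)
# Same normalization as in read_tfrecord.
normalized = (decoded.astype(np.float32) - 127.5) / 128

print(normalized.min(), normalized.max())
```

The round trip only works when the stored bytes are raw pixels of exactly image_size x image_size x 3; a JPEG-encoded buffer would have a different length and the reshape would fail.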
3. Using it in training
'''
other code
'''
image_batch, label_batch, bbox_batch, landmark_batch = read_tfrecord(dataset_dir, BATCH_SIZE)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# begin: start the queue runner threads that feed the input pipeline
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    for step in range(MAX_STEP):
        step = step + 1  # 1-based step counter for the rest of the loop body
        if coord.should_stop():
            break
        image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
            [image_batch, label_batch, bbox_batch, landmark_batch])
        '''
        other code
        '''
except tf.errors.OutOfRangeError:
    print("Success!")
finally:
    # Ask the reader threads to stop and wait for them to finish.
    coord.request_stop()
coord.join(threads)
4. Notes
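- Everything in TensorFlow is an operator or a tensor; you need sess.run() to get actual values
- TFRecordReader keeps dequeuing file names from the queue until the queue is empty
- Before use, the graph's variables must be initialized: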
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)