一、生成TFRecord数据
首先定义两个函数将features打包成example proto。
import tensorflow as tf
def _int64_feature(value):
if not isinstance(value, list):
value = [value]
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def _float_feature(value):
if not isinstance(value, list):
value = [value]
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
创建writer用来写入tfrecord数据。
writer = tf.python_io.TFRecordWriter('test.tfrecord')
创建example,写入到writer中。
for i in range(0, 2):
a = 5 + i
b = 2 * (i + 1)
print("i:", i, a)
print("i:", i, b)
example = tf.train.Example(
features=tf.train.Features(
feature={'a': _int64_feature(a),
'b': _float_feature(b)}))
serialized = example.SerializeToString()
writer.write(serialized)
writer.close()
输出:
i: 0 5
i: 0 2
i: 1 6
i: 1 4
这里循环两次生成两个example写入到writer中。
在当前路径下生成了一个test.tfrecord文件。
二、读取TFRecord文件
下面来看一下如何读取tfrecord文件。
filename_queue = tf.train.string_input_producer(['test.tfrecord'], num_epochs=None)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,
features={
'a': tf.FixedLenFeature([], dtype=tf.int64),
'b': tf.FixedLenFeature([], dtype=tf.float32)
})
现在可以从features里面去数据了。
a_out = features['a']
b_out = features['b']
print(a_out)
print(b_out)
输出:
Tensor("ParseSingleExample/ParseSingleExample:0", shape=(), dtype=int64)
Tensor("ParseSingleExample/ParseSingleExample:1", shape=(), dtype=float32)
在Session run的时候以batch的方式读取数据。batch_size=1,每次只取一个数据。
扫描二维码关注公众号,回复:
2234956 查看本文章
a_batch, b_batch = tf.train.shuffle_batch([a_out, b_out], batch_size=1, capacity=100, min_after_dequeue=50, num_threads=1)
sess.run(tf.global_variables_initializer())
tf.train.start_queue_runners(sess=sess)
a_val, b_val = sess.run([a_batch, b_batch])
print('first run:')
print(a_val)
print(b_val)
a_val, b_val = sess.run([a_batch, b_batch])
print('second run:')
print(a_val)
print(b_val)
输出:
first run:
[6]
[4.]
second run:
[5]
[2.]
可以看到,虽然我们存储的时候是先存的5,2,后存的6, 4,但是因为我们使用了tf.train.shuffle_batch随机乱序的读取数据,所以第一次运行的时候取到了6,4。