前言
TFRecord 这部分内容困扰了我好几天, 不仅是它的 API 十分晦涩且繁琐, 而且网上的大多数相关教程写的都很抽象, 在看了相关的教材之后才终于有了眉目.
TFRecord 的意义在于: 如果你要训练上万张图片, 他们全部塞进内存里可能需要占用数十甚至上百 GB 的空间, 这时候传统的 feed_dict 方式就不能用了, 需要使用 TFRecord 方式, 即将所有数据转换为 TFRecord 格式的二进制文件, 通过调用 TensorFlow 的相关 API 实现高速的顺序读取, 这样可以在内存有限的情况下完成大体积数据集的训练.
Save
首先是将数据转化为 TFRecord 的代码, 它的 API 真的很晦涩…
它将数据集划分为一个一个的 Example, 比如在图片识别的场景下, 一个 Example 包含一张图片的数据和这张图片对应的标签. 然后通过 tf.python_io.TFRecordWriter 将其写入到指定文件中
代码样例
import tensorflow as tf
import numpy as np
def int64_feature(value):
return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
def create_example(x, y):
return tf.train.Example(features = tf.train.Features(feature = {
"x": int64_feature(x),
"y": int64_feature(y)
}))
def save_record(x, y, index):
file_name = './data_%d.record' % index
with tf.python_io.TFRecordWriter(file_name) as writer:
for i, (_x, _y) in enumerate(zip(x, y)):
writer.write(create_example(_x, _y).SerializeToString())
def main(_):
# 这里我生成了 1000 组数据, 每组数据包含一个 x 和一个 y
x = np.arange(1000).astype(np.int64)
y = np.arange(1000).astype(np.int64)
# 将这些数据写入到 10 个文件中
for i in range(10):
save_record(x, y, i)
if __name__ == '__main__':
tf.app.run()
Load
从 TFRecord 文件读取的逻辑是:
- 通过 tf.TFRecordReader 读取文件
- 每次读取一个 example (二进制格式), 然后通过调用 tf.parse_single_example() 将该 example 解码为原本的数据, 从而获得了数据
- 如果你还想要获取一个 batch, 通过 tf.train.shuffle_batch 获取一个局部随机的 batch, 它的逻辑是, 如果你想要一组 16 个数据的 batch, 它会先读比如 2000 个 example, 然后从这 2000 个里面随机给你 16 个数据作为一个 batch, 因此如果你的数据原本是有序的, 那么通过这个方法得到的"随机" batch 其实是一个局部的随机, 我的办法是在将数据转化为 TFRecord 文件之前就打乱其顺序.
代码样例
import tensorflow as tf
import numpy as np
import os
cwd = os.getcwd()
paths = []
for i in range(10):
paths.append(os.path.join(cwd, 'data_%d.record' % i))
filename_queue = tf.train.string_input_producer(paths, num_epochs = 1)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example, features = {
'x': tf.FixedLenFeature([], tf.int64),
'y': tf.FixedLenFeature([], tf.int64)
})
batchs = tf.train.shuffle_batch([x, y], batch_size = 16, capacity = 100, min_after_dequeue=50)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess = sess, coord = coord)
try:
while not coord.should_stop():
print(sess.run(batchs))
except tf.errors.OutOfRangeError:
print('done')
finally:
coord.request_stop()
coord.join(threads)
总结
这部分内容真的把我搞了, 为什么能设计的这么复杂?
我用了两三天的时间才接受了它的设定, 这离我的目标 ( Fine-tuning pre-trained model with large dataset) 更近了一步.