TFRecords是tensorflow官方提供的一种二进制文件格式,它能方便地进行数据的复制、移动,并更好地利用内存,同时不需要单独的标签文件(在读取数据文件时自动添加标签,下面有介绍);在训练时,使用TFRecords中数据的流程:首先生成xxx.tfrecords文件,接着使用input pipeline读取xxx.tfrecords文件/其他支持格式,然后读取并解码数据,随机乱序(shuffle),生成批数据序列(batch);最后输入到模型中。
如果有一串jpg图片地址和相应的标签:Images
和 Labels
1. 生成TFRecords
存入TFRecords文件需要数据先存入名为example的protocol buffer,然后将其serialize成为string才能写入。example中包含features,用于描述数据类型:bytes,float,int64;具体来说,TFRecords文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。以下的代码给出了tf.train.Example的定义。
// One record as stored in a TFRecords file: a wrapper around a Features map.
message Example {
Features features = 1;
};
// Maps each feature name (a string key) to its Feature value.
message Features {
map<string, Feature> feature = 1;
};
// A single feature value; `oneof` means exactly one of the three list
// kinds (bytes, float or int64) is set at a time.
message Feature {
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
# -*- coding: utf-8 -*-
import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
# Root directory that holds one sub-directory of images per class.
cwd = "E:/Anaconda3/tensorflow/Dataset/data/"
# Class names, defined by hand. Kept as a list rather than a set so that
# enumerate() gives every class the same integer label on every run; the
# iteration order of a set of strings can change between interpreter runs.
classes = ['cats', 'dogs']
# Helpers that wrap raw values in the feature types a TFRecord stores.
def _int64_feature(value):
    """Wrap a single integer in a tf.train.Feature backed by an Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature backed by a BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
# Write the dataset into a TFRecord file (xxx.tfrecords).
train_filename = 'tensorflow/train.tfrecords'  # output file path

# The writer is used as a context manager so the file is closed even when
# opening or converting one of the images raises.
with tf.python_io.TFRecordWriter(train_filename) as writer:
    # sorted() pins the class -> label mapping (alphabetical order);
    # iterating a set directly could label the classes differently per run.
    for index, name in enumerate(sorted(classes)):
        class_path = os.path.join(cwd, name)
        for img_name in os.listdir(class_path):
            img_path = os.path.join(class_path, img_name)  # path of one picture
            with Image.open(img_path) as src:
                # Force 3-channel RGB: grayscale, palette or RGBA files would
                # otherwise produce a byte string that is not H*W*3 long.
                # PIL sizes are (width, height), so rows=320 and columns=640.
                img = src.convert('RGB').resize((640, 320))
            img_raw = img.tobytes()  # raw pixel bytes of the resized image

            # Build one Example holding the label and the raw image bytes.
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": _int64_feature(index),
                "img_raw": _bytes_feature(img_raw),
            }))
            # Serialize the Example protocol buffer to a string and append it.
            writer.write(example.SerializeToString())
输入: 数据文件路径 path
输出: xxx.tfrecords文件
2. 读取TFRecord 文件
(1). 用tf.train.string_input_producer 读取tfrecords文件(xxx.tfrecords)的list建立文件名队列(FIFO队列),同时,可以指定num_epochs和shuffle参数,分别表示需要读取数据的轮数以及是否将tfrecords文件的读入顺序打乱;结果:图像路径list
(2). 定义TFRecordReader读取(1)中的序列(图像路径list)返回下一个record;结果:serialize example和feature字典
(3). 用tf.parse_single_example对读取的TFRecords文件进行解码,从((2) serialize example和feature字典)中抽取并返回feature对应的值,此时图像数据对应的值是string,需要经过tf.decode_raw(...) 和 tf.cast(...)等操作,将string类型的图像数据还原为原始图像;同时也可以进行一些preprocessing操作;
(4). 利用tf.train.shuffle_batch(...)和tf.train.batch(...)将(3)中还原原始图像生成batch图像序列
#读取文件
def read_and_decode(filename, batch_size, image_shape=(320, 640, 3)):
    """Build the graph ops that read a TFRecord file and emit shuffled batches.

    Args:
        filename: Path of the .tfrecords file to read.
        batch_size: Number of examples per returned batch.
        image_shape: (height, width, channels) of each stored image. It must
            match what was written: the default corresponds to RGB images
            resized with PIL ``resize((640, 320))``, i.e. width 640 and
            height 320, whose raw bytes are 320 * 640 * 3 long.

    Returns:
        A tuple ``(img_batch, label_batch)``: the batched uint8 image tensor
        and the int32 label tensor reshaped to ``[batch_size]``.
    """
    # Queue of file names, built from the single input file.
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    # read() returns (key, value); only the serialized record is needed.
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
        })
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    # The flat byte vector can only be reshaped to the shape it was written
    # with; a mismatching shape fails at run time.
    img = tf.reshape(img, list(image_shape))
    # Optional normalisation, enable if the model needs it:
    # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(features['label'], tf.int32)
    # Batching is done here; move this out of the function if single
    # examples are wanted instead of batches.
    img_batch, label_batch = tf.train.shuffle_batch(
        [img, label],
        batch_size=batch_size,
        num_threads=64,
        capacity=200,
        min_after_dequeue=150)
    return img_batch, tf.reshape(label_batch, [batch_size])
在读取到队列中后,数据输出之前还要做解码操作,可以使用tf.parse_single_example解析器来解析tf.TFRecordReader读出的记录。这个操作可以将Example协议内存块(protocol buffer)解析为张量;
输入:XXX.tfrecords batch_size
输出: image_batch label_batch
3. 扩展
由于tf.train.shuffle_batch等函数在graph中增加了tf.train.QueueRunner对象(在线程中运行队列的入队操作),需要用tf.train.start_queue_runners启动graph中所有的队列线程;并用tf.train.Coordinator来管理线程(启动多少线程、何时终止线程...)
# Initialize global and local variables.
init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())
sess.run(init_op)

# Create a coordinator and start the queue-runner threads that feed data.
coord = tf.train.Coordinator()
# Pass the session explicitly: without it start_queue_runners falls back to
# the default session, which fails if `sess` was not installed as default.
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    for batch_index in range(3):
        batch_images, batch_labels = sess.run([images, labels])
        # Show at most 10 images, never indexing past the end of the batch.
        for i in range(min(10, len(batch_images))):
            plt.imshow(batch_images[i, ...])
            plt.show()
            print("Current image label is: ", batch_labels[i])
finally:
    # Stop and join the threads, then release the session, even on error.
    coord.request_stop()
    coord.join(threads)
    sess.close()
4. 如何显示xxx.tfrecords文件中的图片
tfrecords_file = 'E:/Anaconda3/tensorflow//dataset/train.tfrecords'
Batch_size = 6
image_batch, label_batch = read_and_decode(tfrecords_file, Batch_size)

with tf.Session() as sess:
    i = 0  # number of batches displayed so far
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        # Fetch and plot a single batch, then leave the loop.
        while not coord.should_stop() and i < 1:
            fetched_images, fetched_labels = sess.run(
                [image_batch, label_batch])
            # Display the first four examples of the batch.
            for idx in range(4):
                print('label: %d' % fetched_labels[idx])
                plt.imshow(fetched_images[idx])
                plt.show()
            i += 1
    except tf.errors.OutOfRangeError:
        print('done!')
    finally:
        # Always stop and join the reader threads.
        coord.request_stop()
    coord.join(threads)
这里batch_size大家可以任意设定,这里设置为6;内层循环的次数决定每个batch显示几幅图片(不能超过batch_size),同时i控制显示的batch数
5. 完整代码
# -*- coding: utf-8 -*-
import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
# Root directory that holds one sub-directory of images per class.
cwd = "E:/Anaconda3/tensorflow/dataset/data/"
# A list (not a set) so enumerate() assigns the same integer label to each
# class on every run; set iteration order is not stable across runs.
classes = ['cats', 'dogs']
# Open the TFRecord output file; the loop further down writes one record per
# image to it and the writer is closed once all images have been written.
writer = tf.python_io.TFRecordWriter('train.tfrecords')
def _int64_feature(value):
    """Wrap a single integer in a tf.train.Feature backed by an Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature backed by a BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
# Convert every image under cwd/<class name>/ into one TFRecord example.
try:
    # sorted() pins the class -> label mapping (alphabetical order);
    # iterating a set directly could label the classes differently per run.
    for index, name in enumerate(sorted(classes)):
        class_path = os.path.join(cwd, name)
        for img_name in os.listdir(class_path):
            img_path = os.path.join(class_path, img_name)  # path of one image
            with Image.open(img_path) as src:
                # Force 3-channel RGB so the stored bytes are always H*W*3;
                # PIL sizes are (width, height), so rows=320 and columns=640.
                img = src.convert('RGB').resize((640, 320))
            img_raw = img.tobytes()  # raw pixel bytes of the resized image
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": _int64_feature(index),
                "img_raw": _bytes_feature(img_raw),
            }))
            # Serialize the Example to a string and append it to the file.
            writer.write(example.SerializeToString())
finally:
    # Close the output file even if one of the images could not be read.
    writer.close()
def read_and_decode(filename, batch_size, image_shape=(320, 640, 3)):
    """Build the graph ops that read a TFRecord file and emit shuffled batches.

    Args:
        filename: Path of the .tfrecords file to read.
        batch_size: Number of examples per returned batch.
        image_shape: (height, width, channels) of each stored image. It must
            match what was written: the default corresponds to RGB images
            resized with PIL ``resize((640, 320))``, i.e. width 640 and
            height 320, whose raw bytes are 320 * 640 * 3 long.

    Returns:
        A tuple ``(img_batch, label_batch)``: the batched uint8 image tensor
        and the int32 label tensor reshaped to ``[batch_size]``.
    """
    # Queue of file names, built from the single input file.
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    # read() returns (key, value); only the serialized record is needed.
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
        })
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    # The flat byte vector can only be reshaped to the shape it was written
    # with; a mismatching shape fails at run time.
    img = tf.reshape(img, list(image_shape))
    label = tf.cast(features['label'], tf.int32)
    img_batch, label_batch = tf.train.shuffle_batch(
        [img, label],
        batch_size=batch_size,
        num_threads=64,
        capacity=2000,
        min_after_dequeue=1500)
    return img_batch, tf.reshape(label_batch, [batch_size])
# Read back the file written earlier in this script; the path must be the
# one handed to TFRecordWriter ('train.tfrecords').
tfrecords_file = 'train.tfrecords'
Batch_size = 6
image_batch, label_batch = read_and_decode(tfrecords_file, Batch_size)

with tf.Session() as sess:
    i = 0  # number of batches displayed so far
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        # Fetch and plot a single batch, then leave the loop.
        while not coord.should_stop() and i < 1:
            image, label = sess.run([image_batch, label_batch])
            # Iterate over the whole batch using the variable defined above
            # (the former BATCH_SIZE name was never defined).
            for j in range(Batch_size):
                print('label: %d' % label[j])
                plt.imshow(image[j, :, :, :])
                plt.show()
            i += 1
    except tf.errors.OutOfRangeError:
        print('done!')
    finally:
        # Always stop and join the reader threads.
        coord.request_stop()
    coord.join(threads)
6. 参考文献
1. https://blog.csdn.net/u012222949/article/details/72875281 有imageFile 和 labelFile, 将imageFile和 labelFile分成train_set test_set
2. https://blog.csdn.net/wiinter_fdd/article/details/72835939 imageFile_train + class{} 类别自动生成 + imageFile_test
3. https://blog.csdn.net/gybheroin/article/details/79800679 同上
4. http://www.cnblogs.com/arkenstone/p/7507261.html 结构特别清晰
5. https://www.cnblogs.com/Charles-Wan/p/6197019.html 读取数据分类清晰