代码:
import tensorflow as tf
import numpy as np
import cv2
# Size of the synthetic test image; create_img builds it as
# (img_height, img_width, img_channel), i.e. rows x cols x channels.
img_width, img_height, img_channel = 4, 3, 3
# Windows absolute path the test image is written to and read back from.
img_path = r"E:\opencvPhoto\photo\count.png"
# Windows absolute path of the TFRecord file that is written and then read.
tfrecord_path = r"F:\tensorflow\tfrecord\test_tfrecord.tfrecords"
def create_img():
    """Generate a small synthetic test image, print it and save it to ``img_path``.

    The image has shape ``(img_height, img_width, img_channel)`` and dtype
    uint8. All channels are zero except channel index 2 (red in OpenCV's BGR
    order), which holds a row-major pixel counter 0, 1, 2, ... so the pixel
    data can be checked after a round trip through a TFRecord.

    Returns:
        The generated uint8 image array.

    Raises:
        OSError: if OpenCV fails to write ``img_path``.
    """
    img = np.zeros((img_height, img_width, img_channel), dtype=np.uint8)
    # The original loop only ever touched channel 2, so images with fewer
    # channels stay all-zero.
    if img_channel > 2:
        # Row-major pixel index in channel 2. Wrap modulo 256 explicitly so
        # images with more than 256 pixels behave deterministically instead
        # of depending on NumPy's out-of-range uint8 assignment rules.
        counter = np.arange(img_height * img_width) % 256
        img[:, :, 2] = counter.reshape(img_height, img_width).astype(np.uint8)
    print(img)
    # cv2.imwrite reports failure (e.g. a missing directory) by returning
    # False rather than raising, so check it explicitly.
    if not cv2.imwrite(img_path, img):
        raise OSError("cv2.imwrite failed to write %s" % img_path)
    return img
def create_TfRecord():
    """Encode the image at ``img_path`` as a single-example TFRecord file.

    The image is loaded with OpenCV, resized to ``img_width`` x ``img_height``,
    flattened to raw bytes and stored with a constant label of 0 in one
    ``tf.train.Example`` written to ``tfrecord_path``.

    Raises:
        OSError: if OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    # cv2.imread signals failure by returning None rather than raising.
    # Check before opening the writer so a bad input path does not leave an
    # empty TFRecord file behind.
    if img is None:
        raise OSError("cv2.imread could not read %s" % img_path)
    # cv2.resize takes dsize as (width, height), not (rows, cols).
    img_resize = cv2.resize(img, (img_width, img_height))
    # Raw bytes in row-major order. read_TfRecord undoes this with
    # decode_raw + reshape.
    img_raw = img_resize.tobytes()
    example = tf.train.Example(features=tf.train.Features(feature={
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[0])),
        "img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
    }))
    # The writer buffers records until it is closed. The context manager
    # closes (and flushes) it even if write() raises.
    with tf.python_io.TFRecordWriter(tfrecord_path) as writer:
        writer.write(example.SerializeToString())
def read_TfRecord(flie_list):
    """Build graph ops that read one decoded image and its label from TFRecords.

    Uses the TF1 queue-based input pipeline, so queue runners must be started
    (``tf.train.start_queue_runners``) before the returned tensors are run.

    Args:
        flie_list: list of TFRecord file paths fed to the filename queue.

    Returns:
        ``(image, label)``: a float32 tensor reshaped to
        ``[img_height, img_width, img_channel]`` and an int64 scalar tensor.
    """
    # Per-record schema; keys must match those written by create_TfRecord.
    schema = {
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    }
    file_queue = tf.train.string_input_producer(flie_list)
    _, record = tf.TFRecordReader().read(file_queue)
    parsed = tf.parse_single_example(record, features=schema)

    # Raw bytes -> flat uint8 vector -> HxWxC image -> float32.
    pixels = tf.reshape(tf.decode_raw(parsed['img_raw'], tf.uint8),
                        [img_height, img_width, img_channel])
    image = tf.cast(pixels, tf.float32)
    # The parsed label is already int64; the cast is kept as in the original.
    label = tf.cast(parsed['label'], tf.int64)
    return image, label
if __name__ == "__main__":
    # Pipeline steps: on a first run, uncomment steps 1 and 2 to create the
    # input image and TFRecord file; step 3 reads the record back.
    # 1. Generate the test image
    # create_img()
    # 2. Write the TFRecord file
    # create_TfRecord()
    # 3. Read the TFRecord file back
    read_img, read_label = read_TfRecord([tfrecord_path])
    with tf.Session() as sess:
        # Coordinator used to stop and join the reader threads.
        coord = tf.train.Coordinator()
        # Start the queue-runner threads that feed the input queues.
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            img, label = sess.run([read_img, read_label])
            print(img)
            print(label)
        finally:
            # Always stop and join the reader threads, even if sess.run fails,
            # so they are not left running while the session closes. join()
            # also re-raises an exception a queue-runner thread reported to
            # the coordinator.
            coord.request_stop()
            coord.join(threads)
输出:
# 1.生成测试图片输出
[[[ 0 0 0]
[ 0 0 1]
[ 0 0 2]
[ 0 0 3]]
[[ 0 0 4]
[ 0 0 5]
[ 0 0 6]
[ 0 0 7]]
[[ 0 0 8]
[ 0 0 9]
[ 0 0 10]
[ 0 0 11]]]
# 3.读取TFRecord输出
[[[ 0. 0. 0.]
[ 0. 0. 1.]
[ 0. 0. 2.]
[ 0. 0. 3.]]
[[ 0. 0. 4.]
[ 0. 0. 5.]
[ 0. 0. 6.]
[ 0. 0. 7.]]
[[ 0. 0. 8.]
[ 0. 0. 9.]
[ 0. 0. 10.]
[ 0. 0. 11.]]]
0