# TFRecord 的创建和读入 (creating and reading TFRecord files)
# 创建 (creation)
import os
import tensorflow as tf
import numpy as np
import random
import cv2
get_ipython().run_line_magic('matplotlib', 'inline')
import matplotlib.pyplot as plt
def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` backed by a one-element Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` backed by a one-element BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def get_image(Path):
    """Load an image file and return it as raw float32 bytes.

    The image is read with ``plt.imread``, resized to 224x224 with
    area interpolation, converted to float32 and shifted/scaled with
    ``img / 255 - 0.5``.

    Args:
        Path: Path of the image file to read.

    Returns:
        The bytes of the processed float32 array, suitable for storing in a
        ``BytesList`` feature.
    """
    # NOTE(review): the /255 scaling assumes imread yields 0..255 values
    # (e.g. 8-bit JPEG) and 3 channels are expected by the reader — confirm
    # for any non-JPEG or grayscale input.
    img = plt.imread(Path)
    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_AREA)
    img = img.astype(np.float32)
    img = img / 255 - 0.5
    # ndarray.tostring() is deprecated (removed in NumPy 2.0); tobytes()
    # returns exactly the same bytes.
    return img.tobytes()
# Directory holding the source images and the index at which the shuffled
# file list is split into a training part and a test part.
_TRAIN_IMAGE_DIR = "/home/winsoul/disk/dogsVScats/data/train/"
_TRAIN_SPLIT = 22000


def _write_tfrecord(tfrecord_path, file_names, image_dir, count=0):
    """Write one ``tf.train.Example`` per image file into ``tfrecord_path``.

    Args:
        tfrecord_path: Output TFRecord file.
        file_names: File names (relative to ``image_dir``) to serialize.
        image_dir: Directory prefix prepended to every file name.
        count: Running number of examples written before this call; progress
            is printed every time the running total hits a multiple of 5000.

    Returns:
        The running total after all ``file_names`` have been written.
    """
    with tf.python_io.TFRecordWriter(tfrecord_path) as writer:
        for file_name in file_names:
            # The text before the first '.' selects the class:
            # 'dog' -> 0, anything else (presumably 'cat') -> 1.
            label = 0 if file_name.split('.')[0] == 'dog' else 1
            img = get_image(image_dir + file_name)
            feature = {
                'label': _int64_feature(label),
                'image': _bytes_feature(img),
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())
            count += 1
            if count % 5000 == 0:
                print(count)
    return count


Train_data_list = os.listdir(_TRAIN_IMAGE_DIR)
random.shuffle(Train_data_list)
total = len(Train_data_list)

# First _TRAIN_SPLIT shuffled files become the training set.
TrainTfrecordPath = '/home/winsoul/disk/dogsVScats/data/tfrecord/train.tfrecords'
j = _write_tfrecord(TrainTfrecordPath, Train_data_list[:_TRAIN_SPLIT], _TRAIN_IMAGE_DIR)

# The remaining files become the test set; the progress counter keeps running
# from the training pass, matching the original output.
TestTfrecordPath = '/home/winsoul/disk/dogsVScats/data/tfrecord/test.tfrecords'
j = _write_tfrecord(TestTfrecordPath, Train_data_list[_TRAIN_SPLIT:total], _TRAIN_IMAGE_DIR, j)
# 读入 (reading)
import numpy as np
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"
import tensorflow as tf
get_ipython().run_line_magic('matplotlib', 'inline')
import matplotlib.pyplot as plt
from PIL import Image
import multiprocessing
# Locations of the serialized training / test TFRecord files.
TrainPath, TestPath = (
    '/home/winsoul/disk/dogsVScats/data/tfrecord/train.tfrecords',
    '/home/winsoul/disk/dogsVScats/data/tfrecord/test.tfrecords',
)

# Run settings. NOTE(review): none of these three are referenced in the code
# visible here — presumably consumed by a training loop elsewhere; confirm.
epoch, DisplayStep, SaveModelStep = 10, 20, 1000
def read_tfrecord(TFRecordPath, num_epochs=None):
    """Build graph ops that read one (image, label) example from a TFRecord file.

    Only graph nodes are created here; nothing is executed, so no session is
    opened (the original opened a throwaway ``tf.Session`` that was never used
    to run anything).

    Args:
        TFRecordPath: Path of the TFRecord file to read.
        num_epochs: Forwarded to ``tf.train.string_input_producer``. ``None``
            (the default, matching the previous behavior) cycles through the
            file indefinitely. When set, the producer raises
            ``tf.errors.OutOfRangeError`` after that many passes, and the
            caller must run ``tf.local_variables_initializer()`` because the
            epoch counter is a local variable.

    Returns:
        A tuple ``(image, label)``: ``image`` is a float32 tensor reshaped to
        ``[224, 224, 3]`` and ``label`` is an int32 scalar tensor.
    """
    feature = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    filename_queue = tf.train.string_input_producer([TFRecordPath],
                                                    num_epochs=num_epochs)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features=feature)

    # The image was stored as raw float32 bytes; decode and restore its shape.
    image = tf.decode_raw(features['image'], tf.float32)
    image = tf.reshape(image, [224, 224, 3])
    label = tf.cast(features['label'], tf.int32)
    return image, label
def Network(BatchSize, learning_rate):
    """Pull shuffled batches from the training TFRecord and display sample images.

    Builds the input pipeline on top of ``read_tfrecord(TrainPath)``, then
    repeatedly fetches a batch, undoes the ``/255 - 0.5`` normalization and
    shows up to the first ten images of the batch with their one-hot label as
    the title. The loop runs until the input queue raises
    ``tf.errors.OutOfRangeError`` or the coordinator is asked to stop.

    Args:
        BatchSize: Number of examples per batch.
        learning_rate: Not used by the code in this function.
    """
    with tf.Session() as sess:
        image, label = read_tfrecord(TrainPath)
        image_Batch, label_Batch = tf.train.shuffle_batch(
            [image, label],
            batch_size=BatchSize,
            capacity=BatchSize * 3 + 200,
            min_after_dequeue=BatchSize)
        label_Batch = tf.one_hot(label_Batch, depth=2)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop():
                img, lbl = sess.run([image_Batch, label_Batch])
                # Invert the stored normalization. Round before casting:
                # float error can leave a value just below an integer, which
                # a bare uint8 cast would truncate to the wrong pixel value.
                img = np.clip(np.rint((img + 0.5) * 255), 0, 255).astype(np.uint8)
                # Show at most ten images; a batch may hold fewer than ten.
                for j in range(min(10, len(img))):
                    plt.imshow(img[j, ...])
                    plt.title(lbl[j])
                    plt.show()
        except tf.errors.OutOfRangeError:
            print("OutofRangeError!")
        finally:
            # Always stop and join the queue-runner threads, even when the
            # loop is interrupted by an unexpected exception. The enclosing
            # `with` block closes the session, so no explicit close is needed.
            coord.request_stop()
            coord.join(threads)
def main():
    """Script entry point: run ``Network`` with batch size 16 and learning rate 0.1."""
    batch_size = 16
    learning_rate = 0.1
    Network(batch_size, learning_rate)


if __name__ == '__main__':
    main()