【深度学习】tfrecord 的创建和读入

tfrecord 的创建和读入

创建（写入 TFRecord 文件）

#!/usr/bin/env python
# coding: utf-8

# In[2]:


import os
import tensorflow as tf
import numpy as np
import random
import cv2
get_ipython().run_line_magic('matplotlib', 'inline')
import matplotlib.pyplot as plt 


# In[3]:


def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` backed by an Int64List."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` backed by a BytesList."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)


# In[4]:


def get_image(Path, size=(224, 224)):
    """Load an image file and return it as raw float32 bytes for a TFRecord.

    The image is resized with ``cv2.INTER_AREA``, forced to three channels,
    scaled to [0, 1] and shifted to [-0.5, 0.5].

    Args:
        Path: path of the image file to read.
        size: target ``(width, height)`` passed to ``cv2.resize``. The default
            matches the ``[224, 224, 3]`` reshape done when reading back.

    Returns:
        bytes: the C-ordered float32 buffer of the ``(H, W, 3)`` array, meant
        to be decoded with ``tf.decode_raw(..., tf.float32)``.
    """
    img = plt.imread(Path)
    img = cv2.resize(img, size, interpolation=cv2.INTER_AREA)
    # Grayscale files decode to 2-D arrays and PNGs may carry an alpha
    # channel; either would break a fixed 3-channel reshape when reading back.
    if img.ndim == 2:
        img = np.stack([img, img, img], axis=-1)
    elif img.shape[-1] == 4:
        img = img[..., :3]
    # plt.imread returns integer data (0-255 for 8-bit) for non-PNG formats
    # but floats already in [0, 1] for PNG; only integer data is rescaled.
    if np.issubdtype(img.dtype, np.integer):
        scale = np.iinfo(img.dtype).max
    else:
        scale = 1.0
    img = img.astype(np.float32) / scale - 0.5
    # tostring() is deprecated; tobytes() yields the same C-ordered buffer.
    return img.tobytes()


# In[5]:


# All file names in the training image folder, shuffled in place so that the
# positional train/test slicing applied to this list is a random split.
Train_data_list = [name for name in os.listdir("/home/winsoul/disk/dogsVScats/data/train/")]
random.shuffle(Train_data_list)


# In[6]:


#  for Date_name in Train_data_list:
#         label = Date_name.split('.')[0]
#         label = 0 if (label == 'dog') else 1
#         img = get_image('/home/winsoul/disk/dogsVScats/data/train/' + Date_name);


# In[7]:


# Folder holding the images named in Train_data_list.
TrainImageDir = '/home/winsoul/disk/dogsVScats/data/train/'
# How many shuffled files go to train.tfrecords; the remainder go to test.
TrainSplit = 22000
# File-name prefix (text before the first '.') -> integer class label.
LabelMap = {'dog': 0, 'cat': 1}


def _label_from_name(file_name):
    """Return the integer label encoded in the file name's prefix.

    Raises:
        ValueError: if the prefix is not a key of LabelMap, rather than
            silently labelling an unexpected file as class 1.
    """
    prefix = file_name.split('.')[0]
    try:
        return LabelMap[prefix]
    except KeyError:
        raise ValueError('cannot infer label from file name {!r}'.format(file_name)) from None


def _write_split(tfrecord_path, file_names, count=0):
    """Serialize (image, label) examples for ``file_names`` into one TFRecord.

    Args:
        tfrecord_path: output ``.tfrecords`` path (overwritten).
        file_names: image file names inside TrainImageDir.
        count: running example count to continue from; it is printed each time
            it reaches a multiple of 5000.

    Returns:
        The updated running count.
    """
    with tf.python_io.TFRecordWriter(tfrecord_path) as writer:
        for file_name in file_names:
            label = _label_from_name(file_name)
            img = get_image(os.path.join(TrainImageDir, file_name))
            feature = {
                'label': _int64_feature(label),
                'image': _bytes_feature(img),
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())
            count += 1
            if count % 5000 == 0:
                print(count)
    return count


TrainTfrecordPath = '/home/winsoul/disk/dogsVScats/data/tfrecord/train.tfrecords'
total = len(Train_data_list)
if total <= TrainSplit:
    print('warning: only %d files found; the test split will be empty' % total)
j = _write_split(TrainTfrecordPath, Train_data_list[:TrainSplit])


# In[9]:


TestTfrecordPath = '/home/winsoul/disk/dogsVScats/data/tfrecord/test.tfrecords'
# The progress counter deliberately continues from the training count.
j = _write_split(TestTfrecordPath, Train_data_list[TrainSplit:total], j)



读入（读取 TFRecord 文件）

#!/usr/bin/env python
# coding: utf-8

# In[19]:


import numpy as np
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"
import tensorflow as tf
get_ipython().run_line_magic('matplotlib', 'inline')
import matplotlib.pyplot as plt 
from PIL import Image
import multiprocessing


# In[20]:



# TFRecord files produced by the creation script.
TrainPath, TestPath = (
    '/home/winsoul/disk/dogsVScats/data/tfrecord/train.tfrecords',
    '/home/winsoul/disk/dogsVScats/data/tfrecord/test.tfrecords',
)
# Training-loop settings; not referenced by the code visible in this script.
epoch, DisplayStep, SaveModelStep = 10, 20, 1000


# In[21]:


def read_tfrecord(TFRecordPath, num_epochs=None, image_size=224):
    """Build graph ops that read one (image, label) example from a TFRecord.

    No session is needed to build these ops. They are evaluated later by the
    caller's session once queue runners are started.

    Args:
        TFRecordPath: path of the ``.tfrecords`` file to read.
        num_epochs: forwarded to ``tf.train.string_input_producer``. ``None``
            (the default, unchanged from before) cycles through the file
            forever. An int makes the queue raise ``OutOfRangeError`` after
            that many passes; TensorFlow then needs
            ``tf.local_variables_initializer()`` to have been run.
        image_size: side length used to reshape the flat float32 buffer to
            ``[image_size, image_size, 3]``.

    Returns:
        (image, label): a float32 image tensor and an int32 scalar label.
    """
    feature = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    filename_queue = tf.train.string_input_producer([TFRecordPath], num_epochs=num_epochs)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features=feature)
    image = tf.decode_raw(features['image'], tf.float32)
    image = tf.reshape(image, [image_size, image_size, 3])
    label = tf.cast(features['label'], tf.int32)
    return image, label


# In[22]:


def Network(BatchSize, learning_rate):
    """Preview shuffled training batches read from the TrainPath TFRecord.

    Builds the input pipeline, then repeatedly fetches a batch and plots up to
    its first 10 images, titled with their one-hot labels. The loop ends when
    the input queue raises ``OutOfRangeError`` or the coordinator is stopped.

    Args:
        BatchSize: number of examples per ``tf.train.shuffle_batch`` batch.
        learning_rate: not used by this function; kept for the caller's
            interface.
    """
    with tf.Session() as sess:
        image, label = read_tfrecord(TrainPath)
        image_Batch, label_Batch = tf.train.shuffle_batch(
            [image, label],
            batch_size=BatchSize,
            capacity=BatchSize * 3 + 200,
            min_after_dequeue=BatchSize,
        )
        label_Batch = tf.one_hot(label_Batch, depth=2)
        # Local variables cover the epoch counter an input producer creates
        # when num_epochs is set.
        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            while not coord.should_stop():
                img, lbl = sess.run([image_Batch, label_Batch])
                # Undo the writer's x/255 - 0.5. Round before casting so float
                # error cannot truncate e.g. 254.9999 down to 254.
                img = np.clip(np.rint((img + 0.5) * 255), 0, 255).astype(np.uint8)
                # Guard against batches smaller than 10 (IndexError before).
                for j in range(min(10, len(img))):
                    plt.imshow(img[j, ...])
                    plt.title(lbl[j])
                    plt.show()
        except tf.errors.OutOfRangeError:
            print("OutofRangeError!")
        finally:
            # Always stop and join the queue-runner threads, even if plotting
            # raised or the run was interrupted; the `with` closes the session.
            coord.request_stop()
            coord.join(threads)


# In[23]:


def main():
    """Script entry point: preview training batches of 16 images."""
    Network(BatchSize=16, learning_rate=0.1)


# In[24]:


if __name__ == '__main__':
    main()


发布了79 篇原创文章 · 获赞 56 · 访问量 50万+

猜你喜欢

转载自blog.csdn.net/qq_40861916/article/details/99684567