"""Pack the ``cats`` and ``dogs`` image folders into a TFRecord file.

The class folders under ``./data/train`` are serialized into
``./data/train/train.tfrecords``; the second half of the module reads the
file back and displays a few samples to check that it was written correctly.
"""
import os

import cv2
import numpy as np
import tensorflow as tf

# Location of the TFRecord file that is written by convert_to_tfrecord()
# and read back by show().
filename = './data/train/train.tfrecords'

# Shape every stored image is reshaped to when it is decoded.
IMAGE_SHAPE = [227, 227, 3]


def deal(dir):
    """Collect image paths and their class labels in shuffled order.

    Every immediate sub-directory of ``dir`` is treated as one class: the
    folder named ``cats`` gets label 0, any other folder gets label 1.
    Plain files lying directly in ``dir`` (for example a previously written
    ``train.tfrecords``) are ignored, so paths and labels always stay
    aligned.

    Args:
        dir: Root directory that contains one folder per class.

    Returns:
        A tuple ``(image_list, label_list)`` of equal length, holding the
        file paths and the integer labels in the same random order.
    """
    samples = []
    for class_name in sorted(os.listdir(dir)):
        class_dir = os.path.join(dir, class_name)
        if not os.path.isdir(class_dir):
            continue
        label = 0 if class_name == 'cats' else 1
        # Pair each path with its label right away instead of counting
        # files separately, so the two lists cannot drift apart.
        for root, _dirs, files in os.walk(class_dir):
            samples.extend((os.path.join(root, name), label) for name in files)

    order = np.arange(len(samples))
    np.random.shuffle(order)
    image_list = [samples[i][0] for i in order]
    label_list = [samples[i][1] for i in order]
    return image_list, label_list


def int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature``."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def bytes_feature(value):
    """Wrap a single bytes object in a ``tf.train.Feature``."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def convert_to_tfrecord(image_list, label_list):
    """Serialize the images and labels into the TFRecord file.

    Each record stores the raw pixel bytes under ``'image'`` and the class
    under ``'label'``. A file that OpenCV cannot decode is reported on
    stdout and removed from disk; every other error propagates so that a
    failure of the writer never deletes source images.

    NOTE(review): images are stored without resizing, while
    read_and_decode() reshapes to 227x227x3 -- presumably the inputs are
    already that size; confirm.

    Args:
        image_list: Paths of the images to store.
        label_list: Integer label for each path, in the same order.
    """
    print('start transform')
    # The context manager closes the writer even if an error occurs.
    with tf.python_io.TFRecordWriter(filename) as writer:
        for image_path, label in zip(image_list, label_list):
            img = cv2.imread(image_path)
            if img is None:
                # cv2.imread signals an unreadable file by returning None.
                print(image_path)
                if os.path.isfile(image_path):
                    os.remove(image_path)
                continue
            example = tf.train.Example(features=tf.train.Features(
                feature={
                    'label': int64_feature(int(label)),
                    'image': bytes_feature(img.tobytes()),
                }))
            writer.write(example.SerializeToString())
    print('transform end')


# The functions below read the TFRecord file back and display images to
# verify that it was generated correctly.


def read_and_decode(filename):
    """Build the graph ops that read and parse one example from a TFRecord.

    Args:
        filename: Path of the TFRecord file to read.

    Returns:
        A tuple ``(img, label)`` of tensors: the image as uint8 reshaped to
        ``IMAGE_SHAPE`` and the label cast to int32.
    """
    # Files are consumed through an input queue.
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # The feature spec must match the types used when writing.
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'image': tf.FixedLenFeature([], tf.string),
        })
    img = tf.decode_raw(features['image'], tf.uint8)
    img = tf.reshape(img, shape=IMAGE_SHAPE)
    label = tf.cast(features['label'], dtype=tf.int32)
    return img, label


def show():
    """Display ten shuffled samples from the TFRecord file with their labels.

    The label is printed and the image is shown in an OpenCV window; a key
    press advances to the next sample.
    """
    img, label = read_and_decode(filename)
    img_batch, label_batch = tf.train.shuffle_batch(
        [img, label], batch_size=1, capacity=11, min_after_dequeue=5)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for _ in range(10):
                # Fetch both tensors in ONE run call: separate calls would
                # dequeue two different batches and mismatch image and label.
                images, labels = sess.run([img_batch, label_batch])
                print(labels)
                cv2.imshow('img', images[0])
                cv2.waitKey()
        finally:
            # Stop the queue-runner threads before the session closes.
            coord.request_stop()
            coord.join(threads)


if __name__ == '__main__':
    # To (re)generate the TFRecord file first, run:
    #     image_list, label_list = deal('./data/train')
    #     convert_to_tfrecord(image_list, label_list)
    show()
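# Saving and reading TFRecords files.
# (Page residue: "You may also like")
# Reposted from blog.csdn.net/fanzonghao/article/details/80993237
# (Page residue: "Today's recommendations")
# (Page residue: "Weekly ranking")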