在cifar10_input里面用到了多线程处理和图像的预处理,刚好和第七章的内容符合,所以我就把这个函数从前到后理了一遍。
1、先获取数据文件的列表,源码利用的是for循环构建一个文件列表,也可以用TF提供的 tf.train.match_filenames_once 函数来完成。
2、判断文件是否存在,如果不存在,直接抛出错误
3、创建一个文件队列,然后从队列中读取文件内容
4、读取文件就用到了读取文件的函数,在函数里面我们把数据处理好,直接输出结果
5、对得到的图像进行翻转、色彩调整等操作,这一步有改进的空间,然后进行归一化
6、对像素矩阵调整shape,定义好参数,一个batch一个batch的输出
7、用tf.train.batch函数,可以多线程地输出一个batch的图像数据
import tensorflow as tf
import os

# Pipeline hyper-parameters for CIFAR-10.
IMAGE_SIZE = 24
NUM_CLASS = 10
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000


def read_image(filename_queue):
    """Read and parse a single CIFAR-10 example from the binary files.

    Args:
        filename_queue: A queue of file names, as produced by
            tf.train.string_input_producer.

    Returns:
        An object with fields:
            height, width, depth: static record dimensions (32, 32, 3).
            key: scalar string tensor identifying the file and record.
            label: int32 tensor holding the class label.
            uint8image: uint8 tensor of shape [height, width, depth].
    """
    # Lightweight record holder; a plain object used as a namespace.
    class CIFAR10Record(object):
        pass

    result = CIFAR10Record()

    label_bytes = 1  # one leading byte encodes the label
    result.height = 32
    result.width = 32
    result.depth = 3
    image_bytes = result.width * result.height * result.depth
    # Each fixed-length record is label byte + image payload.
    record_bytes = label_bytes + image_bytes

    # Reader that pulls exactly record_bytes per read.
    reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
    result.key, value = reader.read(filename_queue)

    # Decode the raw string into a vector of uint8.
    record_bytes = tf.decode_raw(value, tf.uint8)

    # The first byte is the label.
    result.label = tf.cast(
        tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)

    # The rest is the image, stored depth-major: [depth, height, width].
    depth_major = tf.reshape(
        tf.strided_slice(record_bytes, [label_bytes],
                         [label_bytes + image_bytes]),
        [result.depth, result.height, result.width])
    # Convert [depth, height, width] -> [height, width, depth].
    result.uint8image = tf.transpose(depth_major, [1, 2, 0])
    return result


def _generate_image_and_label_batch(image, label, min_queue_examples,
                                    batch_size, shuffle):
    """Assemble a batch of images and labels using multiple threads.

    Args:
        image: 3-D tensor [height, width, depth] of type float32.
        label: 1-D tensor of type int32.
        min_queue_examples: minimum number of samples kept in the queue,
            so that shuffling is actually effective.
        batch_size: number of examples per batch.
        shuffle: whether to shuffle the examples.

    Returns:
        images: 4-D tensor [batch_size, height, width, 3].
        labels: 1-D tensor [batch_size].
    """
    num_preprocess_threads = 16
    if shuffle:
        images, label_batch = tf.train.shuffle_batch(
            [image, label],
            batch_size=batch_size,
            num_threads=num_preprocess_threads,
            capacity=min_queue_examples + 3 * batch_size,
            min_after_dequeue=min_queue_examples)
    else:
        # BUG FIX: the original assigned to `image`, leaving `images`
        # undefined and raising NameError on the summary/return below
        # whenever shuffle=False.
        images, label_batch = tf.train.batch(
            [image, label],
            batch_size=batch_size,
            num_threads=num_preprocess_threads,
            capacity=min_queue_examples + 3 * batch_size)

    # Visualize the input images in TensorBoard.
    tf.summary.image('images', images)

    return images, tf.reshape(label_batch, [batch_size])


def distorted_inputs(data_dir, batch_size):
    """Build distorted (augmented) input batches for CIFAR-10 training.

    Args:
        data_dir: path to the CIFAR-10 data directory.
        batch_size: number of images per batch.

    Returns:
        images: 4-D tensor [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3].
        labels: 1-D tensor [batch_size].

    Raises:
        ValueError: if any expected data file is missing.
    """
    # The five training shards: data_batch_1.bin .. data_batch_5.bin.
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                 for i in range(1, 6)]
    # Alternative using TF helpers:
    # files = tf.train.match_filenames_once(
    #     os.path.join(data_dir, 'data_batch_*.bin'))
    # filename_queue = tf.train.string_input_producer(files)

    # Fail fast if any shard is missing.
    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)

    # Queue of file names feeding the reader.
    filename_queue = tf.train.string_input_producer(filenames)

    with tf.name_scope('data_augmentation'):
        # One example per read: [height, width, depth] uint8.
        read_input = read_image(filename_queue)
        reshaped_image = tf.cast(read_input.uint8image, tf.float32)

        height = IMAGE_SIZE
        width = IMAGE_SIZE

        # Random 24x24 crop out of the 32x32 image.
        distorted_image = tf.random_crop(reshaped_image, [height, width, 3])
        # Random horizontal flip.
        distorted_image = tf.image.random_flip_left_right(distorted_image)
        # Random brightness perturbation.
        distorted_image = tf.image.random_brightness(distorted_image,
                                                     max_delta=63)
        # Random contrast perturbation.
        distorted_image = tf.image.random_contrast(distorted_image,
                                                   lower=0.2, upper=1.8)
        # Per-image standardization (zero mean, unit variance).
        float_image = tf.image.per_image_standardization(distorted_image)

        # Pin static shapes so downstream ops know them.
        float_image.set_shape([height, width, 3])
        read_input.label.set_shape([1])

        # Keep at least 40% of an epoch queued for good shuffling.
        min_fraction_of_examples_in_queue = 0.4
        min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
                                 min_fraction_of_examples_in_queue)
        # BUG FIX: original formatted the 0.4 fraction with %d (prints 0);
        # the intended value is the example count.
        print('Filling queue with %d CIFAR images before starting to train . '
              'This will take a few minutes' % min_queue_examples)

    return _generate_image_and_label_batch(float_image, read_input.label,
                                           min_queue_examples, batch_size,
                                           shuffle=True)


def inputs(eval_data, data_dir, batch_size):
    """Build undistorted input batches for CIFAR-10 evaluation/training.

    Reads the data without shuffling or augmentation; images are only
    center-cropped/padded to IMAGE_SIZE and standardized.

    Args:
        eval_data: if True read the test set, otherwise the training set.
        data_dir: path to the CIFAR-10 data directory.
        batch_size: number of images per batch.

    Returns:
        images: 4-D tensor [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3].
        labels: 1-D tensor [batch_size].

    Raises:
        ValueError: if any expected data file is missing.
    """
    if not eval_data:
        filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                     for i in range(1, 6)]
        num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
    else:
        filenames = [os.path.join(data_dir, 'test_batch.bin')]
        num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL

    # Fail fast if any shard is missing.
    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)

    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(filenames)
        read_input = read_image(filename_queue)
        reshaped_image = tf.cast(read_input.uint8image, tf.float32)

        height = IMAGE_SIZE
        width = IMAGE_SIZE

        # Deterministic resize: center crop or zero-pad to 24x24.
        resized_image = tf.image.resize_image_with_crop_or_pad(
            reshaped_image, height, width)
        # Per-image standardization (zero mean, unit variance).
        float_image = tf.image.per_image_standardization(resized_image)

        float_image.set_shape([height, width, 3])
        read_input.label.set_shape([1])

        min_fraction_of_examples_in_queue = 0.4
        min_queue_examples = int(num_examples_per_epoch *
                                 min_fraction_of_examples_in_queue)

    return _generate_image_and_label_batch(float_image, read_input.label,
                                           min_queue_examples, batch_size,
                                           shuffle=False)