1.简介
一直想将图片制作成tfrecords文件,然后在模型中运行一下。最初想用的数据集是mnist,但是跑的过程中一直出现问题。找到这一篇知乎上的博客,写的非常不错。
原博客地址:https://zhuanlan.zhihu.com/p/32490882
其代码地址:https://github.com/HelloSangShen/Cat-vs-Dog/
猫狗数据集:https://pan.baidu.com/s/13hw4LK8ihR6-6-8mpjLKDA 密码:dmp4
2. 本文结构
本文以kaggle的猫狗大战为例,完整地描述使用TensorFlow进行一次完整CNN训练的每个步骤。首先介绍如何将图片转为TFRecords文件,然后介绍如何读取该文件的数据并且输入给我们的网络进行训练,并且会展示如何通过hook来监测网络训练的情况(这里没有使用TensorBoard)。最后会简单解读一下MonitoredTrainingSession的使用方法。
3. 正文
3.1 数据处理
有过实践的小伙伴应该能感受到,当有了TensorFlow、PyTorch这样优秀的框架后,构造一个神经网络、进行训练、计算损失函数、预测等都变的相对容易许多。但是数据的预处理仍然是一个相对棘手的问题,尤其是在较大数据集上进行训练时,不能总是使用占位符(placeholder)和feed dict进行数据加载,而TensorFlow提供了另外一种加载方式。这部分就着重介绍如何将图片数据存储为TFRecords,并且通过队列读取给我们的网络。因为网上有非常多介绍TFRecords原理的文章,我这里就不细说了,只给出详细的代码和注释,示范一下如何处理。
def read_images(path): """从源文件/路径读取图像 参数: path: 图像所在的路径即文件夹名称 返回: 返回一个带有所有图像、标签和总数信息的对象 images: 所有的图像数据 labels: 所有标签 num: 数目 """ # 获取文件夹内所有图像文件的文件名和总数 filenames = next(walk(path))[2] num_file = len(filenames) # 初始化图像和标签 images = np.zeros((num_file, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL), dtype=np.uint8) labels = np.zeros((num_file, ), dtype=np.uint8) # 遍历读取文件 for index, filename in enumerate(filenames): # 读取单张图像,并且修改为自定义尺寸 img = imread(join(path, filename)) img = imresize(img, (IMAGE_HEIGHT, IMAGE_HEIGHT)) images[index] = img # TO DO # 这里通过文件名获取标签信息,猫狗大战问题中只有两类,故只有0和1 # 可以根据自己的需要进行修改 # 注意:这里不是one-hot编码 if filename[0:3] == 'cat': labels[index] = int(0) else: labels[index] = int(1) if index % 1000 == 0: print("Reading the %sth image" % index) # 创建一个类,该类携带图像、标签和总数信息 class ImgData(object): pass result = ImgData() result.images = images result.labels = labels result.num = num_file return result
通过上述函数,我们可以读取到文件夹内所有的图片。接下来,我们要把这些图片转为TFRecords文件。
def convert(data, destination): """将图片存储为.tfrecords文件 参数: data: 上述函数返回的ImageData对象 destination: 目标文件名 """ images = data.images labels = data.labels num_examples = data.num # 存储的文件名 filename = destination # 使用TFRecordWriter来写入数据 writer = tf.python_io.TFRecordWriter(filename) # 遍历图片 for index in range(num_examples): # 转为二进制 image = images[index].tostring() label = labels[index] # tf.train下有Feature和Features,需要注意其区别 # 层级关系为Example->Features->Feature example = tf.train.Example(features=tf.train.Features(feature={ 'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])), 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])) })) # 写入 writer.write(example.SerializeToString()) writer.close()
这两个函数就可以把我们的数据集图片全都写入一个.tfrecords文件。如果文件过大,可以写入多个文件。
下面介绍如何从tfrecords文件中批量读取图片和标签。
def read_and_decode(filename_queue): """读取.tfrecords文件 参数: filename_queue: 文件名, 一个列表 返回: img, label: **单张图片和对应标签** """ # 创建一个图节点,该节点负责数据输入 filename_queue = tf.train.string_input_producer([filename_queue]) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) # 解析单个example features = tf.parse_single_example(serialized_example, features={ 'image': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.int64) }) image = tf.decode_raw(features['image'], tf.uint8) image = tf.reshape(image, [IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL]) image = tf.cast(image, tf.float32) label = tf.cast(features['label'], tf.int64) return image, label
我们将数据读取的功能进行封装,代码如下:
def distorted_input(filename, batch_size):
"""建立一个乱序的输入
参数:
filename: tfrecords文件的文件名. 注:该文件名仅为文件的名称,不包含路径和后缀
batch_size: 每次读取的batch size
返回:
images: 一个4D的Tensor. size: [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3]
labels: 1D的标签. size: [batch_size]
"""
# 完整文件名,文件存储在同一路径下的tfrecords文件夹下,名为filename.tfrecords
filename = './tfrecords/' + filename + '.tfrecords'
# 如果路径下没有该文件,说明没有进行转换工作,则将图片转为tfrecords文件
if not os.path.exists(filename):
print('Transfer images to TF_Records')
raw_data = read_images(FLAGS.raw_data_path)
convert(raw_data, filename)
print('End transfering')
image, label = read_and_decode(filename)
# 乱序读入一个batch
images, labels = tf.train.shuffle_batch([image, label], batch_size=batch_size,
num_threads=16, capacity=3000, min_after_dequeue=1000)
return images, labels
以上,我们就完成了数据的读取部分了。下面用一段代码进行测试。
images, labels = catdog_input.distorted_input(FLAGS.tfrecords_file_name, batch_size=4) # from matplotlib import pyplot as plt fig = plt.figure() a = fig.add_subplot(221) b = fig.add_subplot(222) c = fig.add_subplot(223) d = fig.add_subplot(224) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) coord = tf.train.Coordinator() # 开启文件读取队列,开启后才能开始读取数据 threads = tf.train.start_queue_runners(sess=sess, coord=coord) img, label = sess.run([images, labels]) a.imshow(img[0]) a.axis('off') b.imshow(img[1]) b.axis('off') c.imshow(img[2]) c.axis('off') d.imshow(img[3]) d.axis('off') plt.show() coord.request_stop() coord.join(threads)
通过这个简单的测试程序就可以可视化四张图片出来。
3.2 模型
这里,我们使用VGG-16模型来做测试。TensorFlow在搭建网络上非常方便,这里就不给详细代码了(可以参考http://www.cs.toronto.edu/~frossard/post/vgg16/),读者可以在文末的GitHub链接上找到相关代码。
对于准确率、损失函数等,我们参考TensorFlow教程中Cifar10训练的源代码进行实现,将这些函数均封装起来。
def loss(logits, labels): labels = tf.cast(labels, tf.int64) # 注意:我们上面定义的标签不是one-hot编码,故这里调用的是sparse方法 # 如果使用one-hot,调用softmax_cross_entropy_with_logits即可 cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=labels, logits=logits, name='cross_entropy_per_example') loss = tf.reduce_mean(cross_entropy, name='cross_entropy') return loss def accuracy(logits, labels): # 将labels转为one-hot编码进行计算 labels = tf.one_hot(labels, NUM_CLASS) correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1)) accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) return accuracy def train(loss): train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss) return train_op
至此,我们的模型就搭建好了。接下来就是训练步骤。
3.3 训练
下面的train()
函数也是参照cifar10的源码进行实现的。
def train(): # 因为要使用StopAtStepHook,故global_step是必须的 global_step = tf.train.get_or_create_global_step() # 输入 images, labels = catdog_input.distorted_input(FLAGS.tfrecords_name, BATCH_SIZE) logits = catdog_model.inference(images) loss = catdog_model.loss(logits, labels) # accuracy = catdog_model.accuracy(logits, labels) train_op = catdog_model.train(loss) class _LoggerHook(tf.train.SessionRunHook): """ 该类用来打印训练信息 """ def begin(self): self._step = -1 self._start_time = time.time() def before_run(self, run_context): self._step += 1 # 该函数在训练运行之前自动调用 # 在这里返回所有你想在运行过程中查看到的信息 # 以list的形式传递,如:[loss, accuracy] return tf.train.SessionRunArgs(loss) def after_run(self, run_context, run_values): # 打印信息的步骤间隔 display_step = 10 if self._step % display_step == 0: current_time = time.time() duration = current_time - self._start_time self._start_time = current_time # results返回的就是上面before_run()的返回结果,上面是loss故这里是loss # 若输入的是list,返回也是一个list loss = run_values.results # 每秒使用的样本数 examples_per_sec = display_step * BATCH_SIZE / duration # 每batch使用的时间 sec_per_batch = float(duration / display_step) format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 'sec/batch)') print(format_str % (datetime.now(), self._step, loss, examples_per_sec, sec_per_batch)) with tf.train.MonitoredTrainingSession( hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_step), tf.train.NanTensorHook(loss), _LoggerHook()], # 将上面定义的_LoggerHook传入 config=tf.ConfigProto( log_device_placement=False)) as sess: coord = tf.train.Coordinator() # 开启文件读取队列 threads = tf.train.start_queue_runners(sess=sess, coord=coord) while not sess.should_stop(): sess.run(train_op) coord.request_stop() coord.join(threads)
上面就是在猫狗大战数据集上进行的一个完整的图片数据预处理、数据读取、搭建网络、训练并监测的过程。
3.4 评估
因为实验室设备暂时有点问题,没法训练,故现在没法给出结果,以后训练出结果后再来更新吧。
3.5 关于MonitoredTrainingSession
我们在上面的训练中用到了tf.train.MonitoredTrainingSession(...)
。查阅了一下官方文档,该类继承自MonitoredSession类。我们先看看这个父类,官方文档中给了一段如下示例代码 :
saver_hook = CheckpointSaverHook(...) summary_hook = SummarySaverHook(...) with MonitoredSession(session_creator=ChiefSessionCreator(...), hooks=[saver_hook, summary_hook]) as sess: while not sess.should_stop(): sess.run(train_op)
首先,当MonitoredSession初始化的时候,会按顺序执行下面操作:
- 调用hook的
begin()
函数,我们一般在这里进行一些hook内的初始化。比如在上面猫狗大战中的_LoggerHook
里面的_step属性,就是用来记录执行步骤的,但是该参数只在本类中起作用。 - 通过调用
scaffold.finalize()
初始化计算图 - 创建会话
- 通过初始化
Scaffold
提供的操作(op)来初始化模型 - 如果checkpoint存在的话,restore模型的参数
- launches queue runners
- 调用
hook.after_create_session()
然后,当run()
函数运行的时候,按顺序执行下列操作:
- 调用
hook.before_run()
- 调用TensorFlow的
session.run()
- 调用
hook.after_run()
- 返回用户需要的
session.run()
的结果 - 如果发生了AbortedError或者UnavailableError,则在再次执行run()之前恢复或者重新初始化会话
最后,当调用close()退出时,按顺序执行下列操作:
- 调用
hook.end()
- 关闭队列和会话
- 阻止OutOfRange错误
需要注意的是:该类不是一个tf.Session()
,因为它不能被设置为默认会话,不能被传递给saver.save,也不能被传递给tf.train.start_queue_runners
,这也解释了为什么在开启会话后我们必须手动调用tf.train.start_queue_runners()
而MonitoredTrainingSession则比起父类多了许多其他的参数,可以在官方文档获取各参数的说明,这里我们不详细说。但是根据其父类的执行说明,我们就可以很容易理解上面train()
函数中发生了什么。
首先,我们先将计算图的各个节点/操作定义好,构成了一个计算图。然后开启了一个MonitoredTrainingSession来初始化/注册我们的图和其他信息。其中,我们给其传递了3个hook:
tf.train.StopAtStepHook(last_step)
,该hook主要是在训练到特定步数后即请求停止,使用该hook必须要预先定义一个tf.train.get_or_create_global_step()
。否则会抛出运行时错误,见源码:
def begin(self): self._global_step_tensor = training_util._get_or_create_global_step_read() if self._global_step_tensor is None: raise RuntimeError("Global step should be created to use StopAtStepHook.")
tf.train.NanTensorHook(loss)
,该hook用来监测loss,若loss的结果为NaN,抛出异常或者直接停止训练。_LoggerHook()
,该hook是我们自定义的hook,用来监测我们希望在训练过程中能查看的一些数据如loss或者accuracy。首先会随着MonitoredTrainingSession的初始化来调用begin()
函数,我们在这里初始化步数,before_run()
函数会随着sess.run()
的调用而调用。故每训练一步调用一次,这里返回想要打印的信息,随后就调用after_run()
函数,在这里,我们就将需要查看的信息打印出来即可。
随后,我们开启文件读取队列进行数据的输入。然后就一直调用sess.run()
训练直到停下。
4.如何运行
首先得生成tfrecords文件,在当前文件夹下新建一个create_tfrecords.py,然后将下面的代码放进去(其实就是上面的代码)
import tensorflow as tf import numpy as np import os from scipy.misc import imread,imresize from os.path import join from os import walk IMAGE_WIDTH = 224 IMAGE_HEIGHT = 224 IMAGE_CHANNEL = 3 NUM_CLASS = 2 def read_images(path): """Read image from source file/directory Args: path: source derectory Return: An object representing all images and labels, fields: images: all image data labels: all labels num: number of images """ # Get a list filenames filenames = next(walk(path))[2] num_file = len(filenames) # Initialize images and labels. images = np.zeros((num_file, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL), dtype=np.uint8) labels = np.zeros((num_file, ), dtype=np.uint8) # Iterate/Read all files for index, filename in enumerate(filenames): # Read single image and resize it to your expected size img = imread(join(path, filename)) img = imresize(img, (IMAGE_HEIGHT, IMAGE_HEIGHT)) images[index] = img # TO DO: if filename[0:3] == 'cat': labels[index] = int(0) else: labels[index] = int(1) if index % 1000 == 0: print("Reading the %sth image" % index) class ImgData(object): pass result = ImgData() result.images = images result.labels = labels result.num = num_file return result def convert(data, destination): """Convert images to tfrecords Args: data: an object of ImgData, consisting of images, labels and number of images destination: destination filename of tfrecords """ images = data.images labels = data.labels num_examples = data.num # filenale of tfrecords filename = destination writer = tf.python_io.TFRecordWriter(filename) for index in range(num_examples): image = images[index].tostring() label = labels[index] # Attention: Example -> Features -> Feature example = tf.train.Example(features=tf.train.Features(feature={ 'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])), 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])) })) writer.write(example.SerializeToString()) writer.close() if __name__ == '__main__': path = 'kaggle/train' tfrecords_path = 'tfrecords/cat_dog.tfrecords' data = read_images(path) convert(data,tfrecords_paths)
然后直接命令python create_tfrecords.py
然后直接命令python catdog_train.py --tfrecords_name cat_dog \ --max_step 5000
运行结果: