TensorFlow之队列与多线程(二)

学习记录

这篇博客是衔接上一篇中介绍的多线程对队列的操作的内容,继续介绍通过此方法对训练数据的读取以用于训练。这篇博客中的内容是基于假设所有的输入数据都已经整理成了TFRecord格式。关于将输入数据转换成TFRecord格式的这部分内容后续看情况是否写个博客来做个记录。这部分的学习依旧主要是参照着《TensorFlow实战Google深度学习框架》一书中对应的部分,以下也相当于是相应部分的一个阅读笔记。

输入文件队列

TFRecord文件

在正式介绍利用队列与多线程读取数据之前,先给出一个简单的程序来生成样例数据。

# 创建一个TFRecord的帮助文件
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# 模拟海量数据情况下将数据写入不同的文件。num_shards定义了总共写入多少个文件
# instances_per_shard定义了每个文件中多少个数据
num_shards = 2
instances_per_shard = 2
for i in range(num_shards):
    # 将数据划分为多个文件时,可以将不同文件以类似0000n-of-0000m的后缀区分。
    # 其中m表示数据总共被存储在多少个文件中,n表示当前文件的标号
    # 式样的方式既方便了通过正则表达式获取文件列表,又在文件名中加入了更多的信息
    filename = ('./data.tfrecords-%.5d-of-%.5d' % (i, num_shards))
    writer = tf.python_io.TFRecordWriter(filename)
    # 将数据封装成example结构并写入TFRecord文件
    for j in range(instances_per_shard):
        example = tf.train.Example(features=tf.train.Features(feature={
                'i': _int64_feature(i),
                'j': _int64_feature(j)
                }))
        writer.write(example.SerializeToString())
    writer.close()

刚开始看这个代码,好多点都不太明白,下面对代码中一些点做个记录。

  1. tf.train.Feature
    tf.train.Feature有三个属性,分别为tf.train.bytes_list tf.train.float_list tf.train.int64_list,应用时只需要根据相应的值设置tf.train.Feature的属性即可。如:tf.train.Feature(int64_list=tf.train.Int64List(value=[value])),在属性不同时‘int64_list=’也要做出相应的改变。
    特别地,我们还有:tf.train.Features。从名字看。我们应该能猜出它是tf.train.Feature的复数形式,事实上tf.train.Features有属性feature,这个属性的一般设置方法是传入一个字典,如:
feature_dict = { 
"data_id":tf.train.Feature(int64_list= tf.train.Int64List(value=[value1])),
 "data":tf.train.Feature(bytes_list= tf.train. bytesList(value=[value2])) 
}
features = tf.train.Features(feature=feature_dict)
  1. tf.train.Example
    tf.train.Example有一个属性为features,我们只需要将上一个解析中得到的结果再次当做参数传进来即可。另外,tf.train.Example还有一个方法SerializeToString(),这个方法的作用是把tf.train.Example对象序列化为字符串,因为我们在写入文件的时候不能直接处理对象,将其转化为字符串才能处理。如:
example = tf.train.Example(features=features)
example_str = example.SerializeToString()
  1. tf.python_io.TFRecordWriter(path)
    打开路径所指向的文件并创建一个写入它的TFRecordWriter类,类中有函数writer()用于将字符串记录写入文件。

单个样例读取

下面先给出参考代码配合后面的理解。

# 使用tf.train.match_filenames_once函数获取文件列表
files = tf.train.match_filenames_once('./data.tfrecords-*')

# 通过tf.train.string_input_producer函数创建输入列表,输入队列中的文件列表为
# tf.train.match_filenames_once函数获取的文件列表。这里将shuffle参数设置为False
# 来避免随机打乱读文件的顺序。但一般在解决真实问题时,会将shuffle设置为True
filename_queue = tf.train.string_input_producer(files,shuffle=False)

#读取并解析一个样本
reader = tf.TFRecordReader()

aaa, serialized_example = reader.read(filename_queue) #返回文件名和文件

features = tf.parse_single_example(
        serialized_example,
        features={
                'i': tf.FixedLenFeature([], tf.int64),
                'j': tf.FixedLenFeature([], tf.int64),
                })

with tf.Session() as sess:
    # 虽然在本段程序中没有声明任何变量,但在使用tf.train.match_filenames_once
    # 函数时需要初始化一些变量
    tf.local_variables_initializer().run()
    
    print(sess.run(files))
    
    # 声明tf.train.Coordinator类来协同不同的线程,并启动线程
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
    for i in range(7):
        print(sess.run([aaa,features['i'],features['j']]))
    coord.request_stop()
    coord.join(threads)

利用队列和多线程的概念管理输入文件列表和读取样例,主要涉及以下几个函数。

  1. tf.train.match_filenames_once()
    这个函数主要用于获取符合一个正则表达式的所有文件,得到一个文件列表。
    注意:在使用tf.train.match_filenames_once函数是需要初始化一些变量,具体可参考下面给出的代码。
  2. tf.train.string_input_producer()
    这个函数主要用来对文件列列表进行管理。tf.train.string_input_producer函数会使用初始化时提供的文件列表创建一个输入队列,输入队列中的原始的元素为文件列表中的所有文件。这样创建好的输入队列可以作为文件读取的参数。在这个函数中还有个参数shuffle可供设置,用以支持随机打乱文件列表中文件的出队顺序。当一个输入队列中所有文件都被处理完后,它会将初始化时提供的文件列表中的文件全部重新加入队列,它也支持设置num_epochs参数来限制加载初始文件列表的最大轮数。
  3. tf.TFRecordReader()
    TFRecord文件的一个读取类,可通过该类的read()函数从上一个命令中产生的队列中读取,返回读取的文件名和文件
    TFRecord文件的一个读取类,可通过该类的read()函数从上一个命令中产生的队列中读取样例,放回读取的文件名和文件
  4. tf.parse_single_example(serialized,features,name=None,example_names=None)
    读取的数据一般得按照TFRecord文件当初存储的特征形式解析解析单个 Example 原型,这就是这个函数的作用。该函数返回一个存储着标签信息的dict。

组合训练数据(batching)

通过前面的介绍的方法,我们已经可以从文件列表中读取单个样例了,将这些单个样例进行一些预处理操作之后,就可以得到提供给神经网络输入层的训练数据了。但是我们知道,将多个样例组织成一个batch可以提高模型训练的效率,所以在得到单个样例之后还需要将它们组织成batch,然后再提供给神经网络的输入层。

下面照旧先贴一段代码。

files = tf.train.match_filenames_once('./data.tfrecords-*')
filename_queue = tf.train.string_input_producer(files,shuffle=False)
reader = tf.TFRecordReader()
aaa, serialized_example = reader.read(filename_queue)
feature = tf.parse_single_example(
        serialized_example,
        features={
                'i': tf.FixedLenFeature([], tf.int64),
                'j': tf.FixedLenFeature([], tf.int64),
                })

# 使用7.3.2节中的方法读取并解析得到样例。这里假设example结构中i表示一个样例的特征向量
# 比如一个图片的像素矩阵,而j表示该样例对应的标签
example, label = feature['i'], feature['j']

# 一个batch中样例的个数
batch_size = 3

# 组合样例的队列中最多可以存储的样例个数。这个队列如果太大,那么需要占用很多的内存资源
# 如果太小,那么出队操作可能会因为没有数据而被阻碍,从而导致训练效率降低。
# 一般来说,这个队列的大小会和每一个batch的大小相关,下面一行代码给出了设置队列大小的一种方式
capacity = 1000 + 3*batch_size

# 使用tf.train.batch函数来组合样例。[example,label]参数给出了需要组合的元素,
# 一般example,label分别代表训练样本和这个样本所对应的标签
# batch_size参数给出了每个batch中样例的个数
# capacity给出了队列的最大容量
# 当队列长度等于容量时,暂停入队操作,只等待元素出队
# 当元素个数小于容量时,将重新启动入队操作
example_batch, label_batch = tf.train.batch(
        [example,label],
        batch_size=batch_size,
        capacity=capacity)

with tf.Session() as sess:
#    tf.initialize_all_variables().run()
    sess.run(
        [tf.global_variables_initializer(),
         tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
    # 获取并打印组合后的样例。在真实问题中,这个输出一般会作为神经网络的输入
    for i in range(10):
        cur_example_batch, cur_label_batch = sess.run(
                [example_batch, label_batch])
        print(cur_example_batch,cur_label_batch)
    
    coord.request_stop()
    coord.join(threads)

TensorFlow提供了tf.train.batch和tf.train.shuffle_batch函数来将单个的样例组织成batch的输出形式。这两个函数都会生成一个队列,队列的入队操作是生成单个样例的方法,而每次出队得到的是一个batch的样例。它们唯一的区别就在于是否会将数据顺序打乱。上面的代码仅展示了tf.train.batch的使用,tf.train.shuffle_batch的使用与之类似,仅仅是多了一个min_after_dequeue用以限定队列中的最小长度。

本篇主要介绍利用队列对TFRecord文件进行读取,用以获取提供给神经网路输入层的训练数据。估计很多知识点有所缺漏,看后面发现不足再补吧。

猜你喜欢

转载自blog.csdn.net/weixin_43923472/article/details/89438551