在使用TensorFlow进行异步计算时,队列是一种强大的机制。
为了感受一下队列,让我们来看一个简单的例子。我们先创建一个“先入先出”的队列(FIFOQueue),并将其内部所有元素初始化为零。然后,我们构建一个TensorFlow图,它从队列前端取走一个元素,加上1之后,放回队列的后端。慢慢地,队列的元素的值就会增加。
TensorFlow提供了两个类来帮助多线程的实现:tf.Coordinator和 tf.QueueRunner。Coordinator类可以用来同时停止多个工作线程并且向那个在等待所有工作线程终止的程序报告异常,QueueRunner类用来协调多个工作线程同时将多个张量推入同一个队列中。
队列概述
队列,如FIFOQueue和RandomShuffleQueue,在TensorFlow的张量异步计算时都非常重要。
例如,一个典型的输入结构:是使用一个RandomShuffleQueue来作为模型训练的输入:
- 多个线程准备训练样本,并且把这些样本推入队列。
- 一个训练线程执行一个训练操作
同步执行队列。
import tensorflow as tf
#模拟一下同步,先处理数据,然后才能取数据训练
#1、首先定义队列,存储3个数,类型是float32
q = tf.FIFOQueue(3, tf.float32)#先进先出
#放入一些数据
enq_many = q.enqueue_many([[0.1,0.2,0.3],])
#2、定义一些读取数据、取数据过程
out_q = q.dequeue()#取出数据
data = out_q + 1 #+1
en_q = q.enqueue(data)#载入队列
with tf.Session() as sess:
#初始化队列
sess.run(enq_many)
#处理数据
for i in range(100):
sess.run(en_q)
#训练数据
for i in range(q.size().eval()):
print (sess.run(q.dequeue()))
tf.QueueRunner
QueueRunner类会创建一组线程, 这些线程可以重复的执行Enquene操作, 他们使用同一个Coordinator来处理线程同步终止。此外,一个QueueRunner会运行一个closer thread,当Coordinator收到异常报告时,这个closer thread会自动关闭队列。
您可以使用一个queue runner,来实现上述结构。 首先建立一个TensorFlow图表,这个图表使用队列来输入样本。增加处理样本并将样本推入队列中的操作。增加training操作来移除队列中的样本。
tf.Coordinator
Coordinator类用来帮助多个线程协同工作,多个线程同步终止。 其主要方法有:
- should_stop():如果线程应该停止则返回True。
- request_stop(): 请求该线程停止。
- join():等待被指定的线程终止。
首先创建一个Coordinator对象,然后建立一些使用Coordinator对象的线程。这些线程通常一直循环运行,一直到should_stop()返回True时停止。 任何线程都可以决定计算什么时候应该停止。它只需要调用request_stop(),同时其他线程的should_stop()将会返回True,然后都停下来。
异步执行队列:
import tensorflow as tf
#模拟异步,子线程存入样本, 主线程读取样本
#1、定义一个队列,1000
q = tf.FIFOQueue(1000,tf.float32)
#2、定义子线程要做的事情,循环, 值+1,放入队列
var = tf.Variable(0.0)
#实现自增 :tf.assign_add()
data = tf.assign_add(var,tf.constant(1.0))
en_q = q.enqueue(data)
#3、定义队列管理器op,指定子线程干什么,指定子线程个数
q1 = tf.train.QueueRunner(q, enqueue_ops = [en_q]*2)
#初始化变量的op
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
#开启线程管理器
coord = tf.train.Coordinator()
#真正开启子线程
threads = q1.create_threads(sess,coord=coord,start = True)
#主线程,读取数据
for i in range(300):
print (sess.run(q.dequeue()))
#回收
coord.request_stop()
coord.join(threads)