Reference: the corresponding explanation in the official TensorFlow documentation.
A queue is implemented as a node in the computation graph; like a variable, it is a stateful node.
enqueue: push items into the queue, e.g. reading data from disk ahead of time.
dequeue: pop items from the queue, e.g. taking a mini-batch out and feeding it to the compute nodes (a minimal sketch of both ops follows below).
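A minimal sketch of a queue as a stateful graph node with explicit enqueue/dequeue ops (TF 1.x graph mode assumed; the scalar value is just a placeholder):

import tensorflow as tf

# The queue itself is a stateful node in the graph, much like a variable.
q = tf.FIFOQueue(capacity=3, dtypes=tf.float32)

# enqueue / dequeue are ordinary ops that act on that node.
enqueue_op = q.enqueue(10.0)
dequeue_op = q.dequeue()

with tf.Session() as sess:
  sess.run(enqueue_op)          # put one element into the queue
  print(sess.run(dequeue_op))   # take it back out -> 10.0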
The typical pattern has two main steps:
1. Multiple threads prepare training examples and enqueue them.
2. A training thread executes a training op that dequeues mini-batches from the queue.
def simple_shuffle_batch(source, capacity, batch_size=10):
  # Create a random shuffle queue.
  queue = tf.RandomShuffleQueue(capacity=capacity,
                                min_after_dequeue=int(0.9 * capacity),
                                shapes=source.shape, dtypes=source.dtype)
  # Create an op to enqueue one item.
  enqueue = queue.enqueue(source)
  # Create a queue runner that, when started, will launch 4 threads applying
  # that enqueue op.
  num_threads = 4
  qr = tf.train.QueueRunner(queue, [enqueue] * num_threads)
  # Register the queue runner so it can be found and started by
  # `tf.train.start_queue_runners` later (the threads are not launched yet).
  tf.train.add_queue_runner(qr)
  # Create an op to dequeue a batch.
  return queue.dequeue_many(batch_size)
Once started by tf.train.start_queue_runners, or indirectly through tf.train.MonitoredSession, the QueueRunner will launch the threads in the background to fill the queue.
# Create a dataset that counts from 0 to 99.
input = tf.constant(list(range(100)))
input = tf.data.Dataset.from_tensor_slices(input)
input = input.make_one_shot_iterator().get_next()

# Create a slightly shuffled batch from the sorted elements.
get_batch = simple_shuffle_batch(input, capacity=20)

# `MonitoredSession` will start and manage the `QueueRunner` threads.
with tf.train.MonitoredSession() as sess:
  # Since the `QueueRunners` have been started, data is available in the
  # queue, so the `sess.run(get_batch)` call will not hang.
  while not sess.should_stop():
    print(sess.run(get_batch))

Sample output (the exact values vary because of the shuffling):
[ 8 10 7 5 4 13 15 14 25 0]
[23 29 28 31 33 18 19 11 34 27]
[12 21 37 39 35 22 44 36 20 46]
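Equivalently, the same loop can be driven without MonitoredSession by starting the registered QueueRunner threads by hand. A minimal sketch, assuming the get_batch op built above and an arbitrary number of iterations:

with tf.Session() as sess:
  coord = tf.train.Coordinator()
  # Launch the enqueue threads of every QueueRunner registered in the graph.
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  try:
    for _ in range(5):
      print(sess.run(get_batch))
  finally:
    coord.request_stop()   # ask the enqueue threads to stop
    coord.join(threads)    # wait for them to exit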
prefetch_queue
This function packages up what simple_shuffle_batch does by hand: it creates a prefetch queue together with a queue runner that enqueues into it, and registers that runner in the TF QueueRunners collection.
def prefetch_queue(tensors,
                   capacity=8,
                   num_threads=1,
                   dynamic_pad=False,
                   shared_name=None,
                   name=None):
  """Creates a queue to prefetch tensors from `tensors`.

  A queue runner for enqueuing tensors into the prefetch_queue is automatically
  added to the TF QueueRunners collection.

  Example:
  This is for example useful to pre-assemble input batches read with
  `tf.train.batch()` and enqueue the pre-assembled batches. Ops that dequeue
  from the pre-assembled queue will not pay the cost of assembling the batch.

  images, labels = tf.train.batch([image, label], batch_size=32, num_threads=4)
  batch_queue = prefetch_queue([images, labels])
  images, labels = batch_queue.dequeue()
  logits = Net(images)
  loss = Loss(logits, labels)

  Args:
    tensors: A list or dictionary of `Tensors` to enqueue in the buffer.
    capacity: An integer. The maximum number of elements in the queue.  # max queue size
    num_threads: An integer. Number of threads running the enqueue op.  # number of enqueue threads
    dynamic_pad: Boolean. Whether to allow variable dimensions in input shapes.
    shared_name: (optional). If set, this queue will be shared under the given
      name across multiple sessions.
    name: (Optional) A name for the operations.

  Returns:
    A queue from which you can dequeue tensors with the same type and shape
    as `tensors`.
  """
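A rough end-to-end sketch of how prefetch_queue is typically combined with tf.train.batch (TF 1.x with tf.contrib assumed; the import path and the random placeholder tensors are illustrative, not part of the docstring above):

import tensorflow as tf
from tensorflow.contrib.slim.python.slim.data import prefetch_queue

# Placeholder per-example tensors; in practice these come from a file reader/decoder.
image = tf.random_uniform([224, 224, 3])
label = tf.constant(0, dtype=tf.int64)

# Assemble mini-batches with several enqueue threads...
images, labels = tf.train.batch([image, label], batch_size=32, num_threads=4, capacity=64)

# ...and keep a couple of pre-assembled batches waiting in a prefetch queue.
batch_queue = prefetch_queue.prefetch_queue([images, labels], capacity=2)
images, labels = batch_queue.dequeue()   # the training op dequeues whole batches

# MonitoredSession starts the queue runners registered by both calls above.
with tf.train.MonitoredSession() as sess:
  imgs, lbls = sess.run([images, labels])
  print(imgs.shape, lbls.shape)   # (32, 224, 224, 3) (32,)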