TensorFlow Notes - Conditional and Loop Control (8)


tf.while_loop

 

tf.while_loop(
    cond,
    body,
    loop_vars,
    shape_invariants=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    name=None,
    maximum_iterations=None,
    return_same_structure=False
)

Repeat body while the condition cond is True.

cond is a callable returning a boolean scalar tensor. body is a callable returning a (possibly nested) tuple, namedtuple, or list of tensors with the same arity (length and structure) and types as loop_vars. loop_vars is a (possibly nested) tuple, namedtuple, or list of tensors that is passed to both cond and body. cond and body both take as many arguments as there are loop_vars.

In addition to regular tensors or IndexedSlices, the body may accept and return TensorArray objects. The flows of the TensorArray objects will be appropriately forwarded between loops and during gradient calculations.
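As a minimal sketch of this (the names and sizes below are made up for illustration), a TensorArray can be threaded through the loop, written to in body, and stacked into an ordinary tensor afterwards:

import tensorflow as tf

n = 5
ta0 = tf.TensorArray(dtype=tf.float32, size=n)

def body(i, ta):
    # write the square of the loop index at position i
    return i + 1, ta.write(i, tf.cast(i, tf.float32) ** 2)

_, ta_final = tf.while_loop(lambda i, ta: i < n, body, [tf.constant(0), ta0])
squares = ta_final.stack()  # a [5] tensor: [0., 1., 4., 9., 16.]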

while_loop calls cond and body exactly once (inside the call to while_loop, and not at all during Session.run()). while_loop stitches together the graph fragments created during the cond and body calls with some additional graph nodes to make a graph that repeats body until cond returns False.
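A small sketch of what "exactly once" means in practice (the Python print is only an illustration): side effects in body happen once, while the graph is being built, even though the loop executes many iterations at Session.run() time.

import tensorflow as tf

def body(i):
    print("tracing body")  # Python-level side effect: printed once, at graph-construction time
    return i + 1

r = tf.while_loop(lambda i: i < 100, body, [tf.constant(0)])
with tf.Session() as sess:
    print(sess.run(r))  # 100: the loop still executes 100 iterations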

For correctness, tf.while_loop() strictly enforces shape invariants for the loop variables. A shape invariant is a (possibly partial) shape that is unchanged across the iterations of the loop. An error will be raised if the shape of a loop variable after an iteration is determined to be more general than or incompatible with its shape invariant. For example, the shape [11, None] is more general than the shape [11, 17], and [11, 21] is not compatible with [11, 17]. By default (if the argument shape_invariants is not specified), it is assumed that the initial shape of each tensor in loop_vars is the same in every iteration. The shape_invariants argument allows the caller to specify a less specific shape invariant for each loop variable, which is needed if the shape varies between iterations. The tf.Tensor.set_shape function may also be used in the body function to indicate that the output loop variable has a particular shape. The shape invariants for SparseTensor and IndexedSlices are treated specially as follows:
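A minimal sketch of the tf.Tensor.set_shape usage mentioned above (the tf.reshape call is there only to deliberately lose the static shape): without set_shape, the default shape invariant (the initial [4, 3] shape) would be violated by the more general [?, ?] result and while_loop would raise an error.

import tensorflow as tf

i0 = tf.constant(0)
x0 = tf.zeros([4, 3])

def body(i, x):
    x = tf.reshape(x, tf.shape(x))  # static shape becomes [?, ?]
    x.set_shape([4, 3])             # pin the shape back down so the default invariant holds
    return i + 1, x

_, x_final = tf.while_loop(lambda i, x: i < 5, body, [i0, x0])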

a) If a loop variable is a SparseTensor, the shape invariant must be TensorShape([r]), where r is the rank of the dense tensor represented by the sparse tensor. It means the shapes of the three tensors of the SparseTensor are ([None], [None, r], [r]). NOTE: the shape invariant here is the shape of the SparseTensor.dense_shape property; it must be the shape of a vector.
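A hedged sketch of rule (a), with made-up shapes: the sparse tensor is concatenated with itself each iteration, so its number of nonzeros changes, and the invariant supplied for it is TensorShape([2]) because the dense tensor it represents has rank 2.

import tensorflow as tf

i0 = tf.constant(0)
sp0 = tf.SparseTensor(indices=[[0, 0]], values=[1.0], dense_shape=[2, 3])

def body(i, sp):
    # stacking the sparse tensor on top of itself grows both the number of
    # nonzeros and the first dense dimension
    return i + 1, tf.sparse_concat(axis=0, sp_inputs=[sp, sp])

_, sp_final = tf.while_loop(
    lambda i, sp: i < 3, body, [i0, sp0],
    shape_invariants=[i0.get_shape(), tf.TensorShape([2])])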

b) If a loop variable is an IndexedSlices, the shape invariant must be a shape invariant of the values tensor of the IndexedSlices. It means the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]], [shape.ndims]).
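A hedged sketch of rule (b), again with made-up values: the loop variable is an IndexedSlices built by hand, and the invariant supplied for it is a shape invariant of its values tensor.

import tensorflow as tf

i0 = tf.constant(0)
slices0 = tf.IndexedSlices(values=tf.ones([3, 4]),
                           indices=tf.constant([0, 1, 2]),
                           dense_shape=tf.constant([10, 4]))

def body(i, s):
    # scale the values; indices and dense_shape are carried along unchanged
    return i + 1, tf.IndexedSlices(s.values * 2.0, s.indices, s.dense_shape)

_, out = tf.while_loop(
    lambda i, s: i < 5, body, [i0, slices0],
    shape_invariants=[i0.get_shape(), tf.TensorShape([None, 4])])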

while_loop implements non-strict semantics, enabling multiple iterations to run in parallel. The maximum number of parallel iterations can be controlled by parallel_iterations, which gives users some control over memory consumption and execution order. For correct programs, while_loop should return the same result for any parallel_iterations > 0.

For training, TensorFlow stores the tensors that are produced in the forward pass and are needed in backpropagation. These tensors are a main source of memory consumption and often cause OOM errors when training on GPUs. When the flag swap_memory is True, we swap these tensors from GPU to CPU. This, for example, allows us to train RNN models with very long sequences and large batches.
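A minimal sketch of the swap_memory flag (the recurrence and sizes are made up to stand in for an RNN cell): the forward-pass activations that backprop needs can be swapped out to host memory so that a long loop with a large batch does not exhaust GPU memory.

import tensorflow as tf

seq_len, batch, dim = 10000, 32, 128
inputs = tf.random_normal([seq_len, batch, dim])   # [time, batch, features]
w = tf.Variable(tf.random_normal([dim, dim]))
h0 = tf.zeros([batch, dim])

def body(t, h):
    # a toy recurrence standing in for an RNN cell
    return t + 1, tf.tanh(tf.matmul(inputs[t], w) + h)

_, h_final = tf.while_loop(lambda t, h: t < seq_len, body, [tf.constant(0), h0],
                           swap_memory=True)  # swap stored activations from GPU to CPU
grads = tf.gradients(tf.reduce_sum(h_final), [w])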

Args:

  • cond: A callable that represents the termination condition of the loop.
  • body: A callable that represents the loop body.
  • loop_vars: A (possibly nested) tuple, namedtuple, or list of numpy array, Tensor, and TensorArray objects.
  • shape_invariants: The shape invariants for the loop variables.
  • parallel_iterations: The number of iterations allowed to run in parallel. It must be a positive integer.
  • back_prop: Whether backprop is enabled for this while loop.
  • swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
  • name: Optional name prefix for the returned tensors.
  • maximum_iterations: Optional maximum number of iterations of the while loop to run. If provided, the cond output is AND-ed with an additional condition ensuring the number of iterations executed is no greater than maximum_iterations (a capped variant of the first example below illustrates this).
  • return_same_structure: If True, output has same structure as loop_vars. If eager execution is enabled, this is ignored (and always treated as True).

Returns:

The output tensors for the loop variables after the loop. If return_same_structure is True, the return value has the same structure as loop_vars. If return_same_structure is False, the return value is a Tensor, TensorArray or IndexedSlice if the length of loop_vars is 1, or a list otherwise.

Raises:

  • TypeError: if cond or body is not callable.
  • ValueError: if loop_vars is empty.

Example:

 

i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: tf.add(i, 1)
r = tf.while_loop(c, b, [i])
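
A hedged variant of the example above showing maximum_iterations: although cond alone would allow 10 iterations, the extra AND-ed condition stops the loop after 5.

import tensorflow as tf

i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: tf.add(i, 1)
r = tf.while_loop(c, b, [i], maximum_iterations=5)  # r evaluates to 5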

Example with nesting and a namedtuple:

 

import collections
Pair = collections.namedtuple('Pair', 'j, k')
ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
c = lambda i, p: i < 10
b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
ijk_final = tf.while_loop(c, b, ijk_0)

Example using shape_invariants:

 

i0 = tf.constant(0)
m0 = tf.ones([2, 2])
c = lambda i, m: i < 10
b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
tf.while_loop(
    c, b, loop_vars=[i0, m0],
    shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])

Example which demonstrates non-strict semantics: In the following example, the final value of the counter i does not depend on x. So the while_loop can increment the counter in parallel with updates of x. However, because the loop counter at one loop iteration depends on the value at the previous iteration, the loop counter itself cannot be incremented in parallel. Hence if we just want the final value of the counter (which we print on the line print(sess.run(i))), then x will never be incremented, but the counter will be updated on a single thread. Conversely, if we want the value of the output (which we print on the line print(sess.run(out).shape)), then the counter may be incremented on its own thread, while x can be incremented in parallel on a separate thread. In the extreme case, it is conceivable that the thread incrementing the counter runs until completion before x is incremented even a single time. The only thing that can never happen is that the thread updating x can never get ahead of the counter thread because the thread incrementing x depends on the value of the counter.

 

import tensorflow as tf

n = 10000
x = tf.constant(list(range(n)))
c = lambda i, x: i < n
b = lambda i, x: (tf.Print(i + 1, [i]), tf.Print(x + 1, [i], "x:"))
i, out = tf.while_loop(c, b, (0, x))
with tf.Session() as sess:
    print(sess.run(i))  # prints [0] ... [9999]

    # The following line may increment the counter and x in parallel.
    # The counter thread may get ahead of the other thread, but not the
    # other way around. So you may see things like
    # [9996] x:[9987]
    # meaning that the counter thread is on iteration 9996,
    # while the other thread is on iteration 9987
    print(sess.run(out).shape)
