tf.TensorArray和tf.while_loop组合使用

TensorArray

TensorArray可以看做是具有动态size功能的Tensor数组。通常都是跟while_loop或map_fn结合使用。
常用方法有

  • write(index,value):将value写入TensorArray的第index个位置
  • stack:将TensorArray中的值作为Tensor返回

while_loop

final_state = tf.while_loop(cond, loop_body, init_state),作用是循环处理某个变量,中间处理的结果用来进行下一次处理,最后输出经过数次加工的变量,由于TensorArray可以动态扩展,因此常用来存储中间结果。

  • cond:是一个函数,负责判断继续执行循环的条件。
  • loop_body:是每个循环体内执行的操作,负责对循环状态迸行更新。
  • init_state:为循环的起始状态,它可以包含多个Tensor 或者 TensorArray 。

如果用伪代码来表示运行逻辑的话,那 tf.while_loop 的功能与下面的代码相当 :

def while_loop(cond, loop_body, init_state): 
    state = init_state 
    while(cond(state)) :   # 使用cond函数判断是否达到循环结束条件。
        state = loop_body(state)   # 使用loop_body函数对state进行更新。
    return state 

例子:

import tensorflow as tf

def condition(time, output_ta_l):
    return tf.less(time, 3)  # 真值比较 time小于3返回True 否则False

def body(time, output_ta_l):
    output_ta_l = output_ta_l.write(time, [2.4, 3.5])
    return time + 1, output_ta_l

time = tf.constant(0)
output_ta = tf.TensorArray(dtype=tf.float32, size=1, dynamic_size=True)
print(output_ta)
>>> <tensorflow.python.ops.tensor_array_ops.TensorArray object at 0x0000016135AB4D88>
result = tf.while_loop(condition, body, loop_vars=[time, output_ta])
last_time, last_out = result
final_out = last_out.stack()
print(last_time.numpy())
>>> 3
print(last_out) # 还未解析的TensorArray
>>> <tensorflow.python.ops.tensor_array_ops.TensorArray object at 0x000002189D415E08>
print(final_out.numpy()) # time从0到3,一共向last_out写了三次
>>> [[2.4 3.5]
	 [2.4 3.5]
	 [2.4 3.5]]

参考

发布了83 篇原创文章 · 获赞 4 · 访问量 5354

猜你喜欢

转载自blog.csdn.net/weixin_43486780/article/details/105488849
tf