从一句代码看tf.scan

在读这篇文章的时候遇到了以下代码:

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
init_state = cell.zero_state(batch_size, tf.float32)

rnn_outputs, final_states = \
       tf.scan(lambda a, x: cell(x, a[1]),
               tf.transpose(rnn_inputs, [1,0,2]),
               initializer=(tf.zeros([batch_size, state_size]), init_state))

这里来解释一下:

首先,tf.scan 第一个输入是函数,也就是:

tf.scan(lambda a, x: cell(x, a[1])

等价于(未验证,仅作illustration用):

def func(a, x):
    return cell(x, a[1])

x是输入,a是上一步函数func的输出。为什么输入cell的是a[1]呢?这是因为,根据官方文档,MultiRNNCell的输出是:

Returns:

A pair containing:

Output: A 2-D tensor with shape [batch_size, self.output_size].
New state: Either a single 2-D tensor, or a tuple of tensors matching the arity and shapes of state.

换句话说,就是: (output,New state),也就是a。

那么:

a = (output, new_state)
a[0] = output
a[1] = new_state

所以,cell的输入,其一是x,也就是每一个time step的输入,其二是a[1],也就是上一个time step 输出的hidden state。

然后,tf.scan 的第二个输入是input,这个没什么好说的,需要注意数据的形状要从[batch_size,num_steps, state_size] 调整为[num_steps, batch_size, state_size]。tf.scan 会一步一步的把input输入cell,每次的形状是:[batch_size, state_size]

tf.scan第三个参数是a的初始化,那么水到渠成,它分别初始化了output和new_state:

initializer=(tf.zeros([batch_size, state_size]), init_state))

至此这句代码就分析完毕了。有不懂的同学还请细细钻研,弄懂了就不难。另附例子如下

扫描二维码关注公众号,回复: 4911505 查看本文章
def testScan_SingleInputMultiOutput(self):
  with self.test_session() as sess:
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = (np.array(1.0), np.array(-1.0))
    r = tf.scan(lambda a, x: (a[0] * x, -a[1] * x), elems, initializer)
    r_value = sess.run(r)
 
    self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0])
    self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1])

猜你喜欢

转载自blog.csdn.net/qq_20289205/article/details/86474578