在读这篇文章的时候遇到了以下代码:
cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
init_state = cell.zero_state(batch_size, tf.float32)
rnn_outputs, final_states = \
tf.scan(lambda a, x: cell(x, a[1]),
tf.transpose(rnn_inputs, [1,0,2]),
initializer=(tf.zeros([batch_size, state_size]), init_state))
这里来解释一下:
首先,tf.scan 第一个输入是函数,也就是:
tf.scan(lambda a, x: cell(x, a[1])
等价于(未验证,仅作illustration用):
def func(a, x):
return cell(x, a[1])
x是输入,a是上一步函数func的输出。为什么输入cell的是a[1]呢?这是因为,根据官方文档,MultiRNNCell的输出是:
Returns:
A pair containing:
Output: A 2-D tensor with shape [batch_size, self.output_size].
New state: Either a single 2-D tensor, or a tuple of tensors matching the arity and shapes of state.
换句话说,就是: (output,New state),也就是a。
那么:
a = (output, new_state)
a[0] = output
a[1] = new_state
所以,cell的输入,其一是x,也就是每一个time step的输入,其二是a[1],也就是上一个time step 输出的hidden state。
然后,tf.scan 的第二个输入是input,这个没什么好说的,需要注意数据的形状要从[batch_size,num_steps, state_size] 调整为[num_steps, batch_size, state_size]。tf.scan 会一步一步的把input输入cell,每次的形状是:[batch_size, state_size]
tf.scan第三个参数是a的初始化,那么水到渠成,它分别初始化了output和new_state:
initializer=(tf.zeros([batch_size, state_size]), init_state))
至此这句代码就分析完毕了。有不懂的同学还请细细钻研,弄懂了就不难。另附例子如下:
扫描二维码关注公众号,回复:
4911505 查看本文章
def testScan_SingleInputMultiOutput(self):
with self.test_session() as sess:
elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
initializer = (np.array(1.0), np.array(-1.0))
r = tf.scan(lambda a, x: (a[0] * x, -a[1] * x), elems, initializer)
r_value = sess.run(r)
self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0])
self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1])