tf.contrib.seq2seq.dynamic_decode 返回值的shape 巨坑

tf.contrib.seq2seq.dynamic_decode 这个函数真的是巨坑啊,每一个batch,他的rnn_output的shape居然会是:
[batch_size, max_efficient_sentence across entire batch, num_classes] ,而不是我们想要的每个序列应该有的最大长度,尽管我们在TrainingHelper里面指定了Sequence_length的有效长度。。。。我日哦。
于是在计算sequence_loss的时候,

def sequence_loss(logits,
                  targets,
                  weights,
                  average_across_timesteps=True,
                  average_across_batch=True,
                  softmax_loss_function=None,
                  name=None):

我们传入的targets是每条序列都填充到某个固定的最大长度,然而logits的sequence大小却是又都是填充到 当前batch里面最大有效长度 ,这两个长度不相等.着实坑人···

举个例子:机器翻译里面,目标语词典大小1100,我们每次送5条目标语序列到模型,送入之前都会把这5个句子填充到指定的长度max_len(例如max_len=200),那么targets的shape就是[5,200,1100]。

然后我们也会保留目标语序列填充前的真实长度数组sequence_lengths,比如说sequence_lengths=[32,43,96,44,76] 。 坑爹的dynamic_decode执行后,当前batch的结果rnn_output的shape居然是[5,96,1100]!!!!!!!!!
!!!
!!!
!!!

解决方法:对targets进行截断,截断成和rnn_output一样的形状。

# 获取logits
logits = tf.contrib.seq2seq.dynamic_decode(xxx)[0].rnn_output

# 获取当前的长度,max_len 和 logits 的较小者.
current_ts = tf.to_int32(tf.minimum(tf.shape(target_input)[1], tf.shape(logits)[1]))
# 对 target 进行截取
target_sequence = tf.slice(target_input, begin=[0, 0], size=[-1, current_ts])
mask_ = tf.sequence_mask(lengths=sequence_lengths, maxlen=current_ts, dtype=logits.dtype)
logits = tf.slice(logits, begin=[0, 0, 0], size=[-1, current_ts, -1])
发布了307 篇原创文章 · 获赞 268 · 访问量 56万+

猜你喜欢

转载自blog.csdn.net/jmh1996/article/details/103647127