0. BasicSeq2Seq
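Let's start from the entry point. The `BasicSeq2Seq` class inherits from `Seq2SeqModel`, and the part related to decoding is shown below. Note that decoding works differently in the training and inference phases.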
```python
@templatemethod("decode")
def decode(self, encoder_output, features, labels):
    decoder = self._create_decoder(encoder_output, features, labels)
    if self.use_beam_search:
        decoder = self._get_beam_search_decoder(decoder)
    bridge = self._create_bridge(
        encoder_outputs=encoder_output,
        decoder_state_size=decoder.cell.state_size)
    if self.mode == tf.contrib.learn.ModeKeys.INFER:
        return self._decode_infer(decoder, bridge, encoder_output, features,
                                  labels)
    else:
        return self._decode_train(decoder, bridge, encoder_output, features,
                                  labels)
```
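With this function in mind, we'll continue in two directions. The first is, of course, the `BeamSearchDecoder` this article is about, which is returned by `_get_beam_search_decoder`. The second is `bridge`: it doesn't appear in the paper, so let's first figure out what it is.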
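1. The Bridge class
I came across this in the code; the paper doesn't mention it.
A `bridge` defines how information is passed from the encoder to the decoder, and there are several different kinds of bridge that can connect the two.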
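For example, the encoder produces a vector of shape [batch size, m], while the decoder needs an input vector of shape [batch size, n], and m and n can differ. In that case a `Bridge` class applies some logic to convert the former into the latter.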
Here is the base class implementation:
```python
import abc

import six
import tensorflow as tf
from tensorflow.python.util import nest

from seq2seq.configurable import Configurable


@six.add_metaclass(abc.ABCMeta)
class Bridge(Configurable):
    """An abstract class that defines how information is passed between
    the encoder and the decoder.

    Args:
      encoder_outputs: A namedtuple that corresponds to the encoder outputs.
      decoder_state_size: An integer or tuple of integers defining the
        state size of the decoder.
    """

    def __init__(self, encoder_outputs, decoder_state_size, params, mode):
        Configurable.__init__(self, params, mode)
        self.encoder_outputs = encoder_outputs
        self.decoder_state_size = decoder_state_size
        # Batch size is read dynamically from the first tensor of the encoder's final state
        self.batch_size = tf.shape(
            nest.flatten(self.encoder_outputs.final_state)[0])[0]

    def __call__(self):
        """Runs the bridge function.

        Returns:
          An initial decoder_state tensor or tuple of tensors.
        """
        return self._create()

    @abc.abstractmethod
    def _create(self):
        """Implements the logic for this bridge.
        This function should be implemented by child classes.

        Returns:
          A tuple initial_decoder_state tensor or tuple of tensors.
        """
        raise NotImplementedError("Must be implemented by child class")
```
All of the logic lives in the `_create` method, which the subclasses implement; it returns the decoder's initial state.
`Bridge` has three subclasses: `ZeroBridge`, `PassThroughBridge`, and `InitialStateBridge`.
1.1 ZeroBridge
No information is passed between the encoder and the decoder; the decoder's initial state is simply set to 0.
```python
class ZeroBridge(Bridge):
    """A bridge that does not pass any information between encoder and decoder
    and sets the initial decoder state to 0. The input function is not modified.
    """

    @staticmethod
    def default_params():
        return {}

    def _create(self):
        zero_state = nest.map_structure(
            lambda x: tf.zeros([self.batch_size, x], dtype=tf.float32),
            self.decoder_state_size)
        return zero_state
```
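As a rough illustration of what `ZeroBridge._create` computes, here is a plain-Python sketch with no TensorFlow. It assumes nested tuples of ints stand in for `decoder_state_size` and lists of rows stand in for tensors; `map_structure` and `zero_bridge` are hypothetical helpers, and unlike `nest.map_structure` this `map_structure` only handles plain tuples.

```python
from typing import Callable, Union

# A (possibly nested) state size: an int, or a tuple of such sizes,
# e.g. (3, 3) for an LSTM-like cell whose state is a (c, h) pair.
StateSize = Union[int, tuple]


def map_structure(fn: Callable, structure):
    """Apply fn to every leaf of a nested tuple (tiny stand-in for nest.map_structure)."""
    if isinstance(structure, tuple):
        return tuple(map_structure(fn, s) for s in structure)
    return fn(structure)


def zero_bridge(batch_size: int, decoder_state_size: StateSize):
    """All-zero initial state with the same nesting as decoder_state_size."""
    return map_structure(
        lambda n: [[0.0] * n for _ in range(batch_size)], decoder_state_size)


state = zero_bridge(batch_size=2, decoder_state_size=(3, 3))
print(len(state), len(state[0]), len(state[0][0]))  # 2 2 3
```

The output keeps the nesting of the state size: two components, each a [batch size, 3] block of zeros.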
1.2 PassThroughBridge
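This bridge can be used only when the decoder and the encoder have the same state size (for example, when they use the same RNN cell). In that case the encoder's final state is fed directly to the decoder as its initial state.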
```python
class PassThroughBridge(Bridge):
    """Passes the encoder state through to the decoder as-is. This bridge
    can only be used if encoder and decoder have the exact same state size, i.e.
    use the same RNN cell.
    """

    @staticmethod
    def default_params():
        return {}

    def _create(self):
        nest.assert_same_structure(self.encoder_outputs.final_state,
                                   self.decoder_state_size)
        return self.encoder_outputs.final_state
```
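The only thing `PassThroughBridge` verifies is that the encoder's final state and the decoder's state size share the same nesting. A minimal plain-Python sketch of that check follows; `assert_same_structure` and `pass_through_bridge` here are hypothetical stand-ins, and, like `nest.assert_same_structure`, the check compares only the structure, not the sizes, so it does not by itself catch two states with the same nesting but different sizes.

```python
def assert_same_structure(a, b):
    """Raise ValueError unless a and b have the same tuple nesting (leaves are not compared)."""
    a_seq, b_seq = isinstance(a, tuple), isinstance(b, tuple)
    if a_seq != b_seq or (a_seq and len(a) != len(b)):
        raise ValueError(f"Structures differ: {a!r} vs {b!r}")
    if a_seq:
        for x, y in zip(a, b):
            assert_same_structure(x, y)


def pass_through_bridge(encoder_final_state, decoder_state_size):
    """Return the encoder's final state unchanged after the structure check."""
    assert_same_structure(encoder_final_state, decoder_state_size)
    return encoder_final_state


# An LSTM-like (c, h) encoder state (batch 1, depth 2) and a matching (2, 2) decoder state size.
encoder_state = ([[0.5, -0.5]], [[0.1, 0.2]])
print(pass_through_bridge(encoder_state, (2, 2)) is encoder_state)  # True

try:
    pass_through_bridge([[0.5, -0.5]], (2, 2))  # a single tensor vs. a (c, h) pair
except ValueError as err:
    print("rejected:", err)
```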
1.3 InitialStateBridge
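There's no problem that can't be solved by adding another layer~ So when the encoder and decoder state sizes differ, a fully connected (FC) layer maps the encoder output (by default, its final state) to the decoder's initial state.
This looks like the most commonly used bridge, and judging from the code, it is indeed the one that gets used.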
```python
from pydoc import locate  # resolves a dotted name such as "tensorflow.identity"


class InitialStateBridge(Bridge):
    """A bridge that creates an initial decoder state based on the output
    of the encoder. This state is created by passing the encoder outputs
    through an additional layer to match them to the decoder state size.
    The input function remains unmodified.

    Args:
      encoder_outputs: A namedtuple that corresponds to the encoder outputs.
      decoder_state_size: An integer or tuple of integers defining the
        state size of the decoder.
      bridge_input: Which attribute of the `encoder_outputs` to use for the
        initial state calculation. For example, "final_state" means that
        `encoder_outputs.final_state` will be used.
      activation_fn: An optional activation function for the extra
        layer inserted between encoder and decoder. A string for a function
        name contained in `tf.nn`, e.g. "tanh".
    """

    def __init__(self, encoder_outputs, decoder_state_size, params, mode):
        super(InitialStateBridge, self).__init__(encoder_outputs,
                                                 decoder_state_size, params, mode)
        if not hasattr(encoder_outputs, self.params["bridge_input"]):
            raise ValueError("Invalid bridge_input not in encoder outputs.")
        self._bridge_input = getattr(encoder_outputs, self.params["bridge_input"])
        self._activation_fn = locate(self.params["activation_fn"])

    @staticmethod
    def default_params():
        return {
            "bridge_input": "final_state",
            "activation_fn": "tensorflow.identity",
        }

    def _create(self):
        # Concat bridge inputs on the depth dimensions
        # (_total_tensor_depth is a helper in the same module: a tensor's size without the batch dimension)
        bridge_input = nest.map_structure(
            lambda x: tf.reshape(x, [self.batch_size, _total_tensor_depth(x)]),
            self._bridge_input)
        bridge_input_flat = nest.flatten([bridge_input])
        bridge_input_concat = tf.concat(bridge_input_flat, 1)

        state_size_splits = nest.flatten(self.decoder_state_size)
        total_decoder_state_size = sum(state_size_splits)

        # Pass bridge inputs through a fully connected layer
        initial_state_flat = tf.contrib.layers.fully_connected(
            inputs=bridge_input_concat,
            num_outputs=total_decoder_state_size,
            activation_fn=self._activation_fn)

        # Shape back into required state size
        initial_state = tf.split(initial_state_flat, state_size_splits, axis=1)
        return nest.pack_sequence_as(self.decoder_state_size, initial_state)
```
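To make the concat, FC, split sequence in `InitialStateBridge._create` concrete, here is a plain-Python sketch. Assumptions: the encoder state is a flat tuple of [batch, depth] matrices (no deeper nesting), the result is returned as a flat tuple instead of being repacked with `pack_sequence_as`, and the FC weights and bias are passed in explicitly, whereas `tf.contrib.layers.fully_connected` creates and trains them. `initial_state_bridge` is a hypothetical name.

```python
from typing import Callable, List, Sequence, Tuple

Matrix = List[List[float]]  # [batch, depth], a list of rows


def initial_state_bridge(encoder_state: Tuple[Matrix, ...],
                         weights: Matrix,
                         bias: List[float],
                         state_size_splits: Sequence[int],
                         activation: Callable[[float], float] = lambda v: v) -> Tuple[Matrix, ...]:
    """Concat encoder state along depth, apply activation(x @ W + b), split per decoder state part."""
    batch = len(encoder_state[0])
    # 1. Concatenate the state components along the depth axis (like tf.concat(..., 1)).
    concat = [sum((part[b] for part in encoder_state), []) for b in range(batch)]
    # 2. Fully connected layer with sum(state_size_splits) outputs.
    total = sum(state_size_splits)
    flat = [[activation(sum(x[i] * weights[i][j] for i in range(len(x))) + bias[j])
             for j in range(total)] for x in concat]
    # 3. Split back along depth (like tf.split(..., axis=1)).
    parts, start = [], 0
    for size in state_size_splits:
        parts.append([row[start:start + size] for row in flat])
        start += size
    return tuple(parts)


# Hypothetical example: an LSTM-like (c, h) encoder state (batch 1, depth 2 each),
# mapped to a decoder whose state is split into parts of size 1 and 2.
identity_like = [[1.0 if i == j else 0.0 for j in range(3)] for i in range(4)]
state = initial_state_bridge(([[1.0, 2.0]], [[3.0, 4.0]]), identity_like,
                             bias=[0.0, 0.0, 0.5], state_size_splits=(1, 2))
print(state)  # ([[1.0]], [[2.0, 3.5]])
```

The default activation is the identity, matching the `"tensorflow.identity"` default in `default_params`.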
2. The BeamSearchDecoder class
Besides the beam search decoder covered here, there is also a decoder with attention; both are built on top of the most basic decoder.
A decoder that uses beam search. Can only be used for inference, not training.
When decoding with beam search, the input batch size must be 1; the decoder then works with an effective batch size of `beam_width`.
```python
class BeamSearchDecoder(RNNDecoder):
    """The BeamSearchDecoder wraps another decoder to perform beam search instead
    of greedy selection. This decoder must be used with batch size of 1, which
    will result in an effective batch size of `beam_width`.
    """

    def __init__(self, decoder, config):
        """
        Args:
          decoder: An instance of `RNNDecoder`, i.e. an RNN cell wrapped as a decoder.
          config: Holds the beam search parameters.
        """
        super(BeamSearchDecoder, self).__init__(decoder.params, decoder.mode,
                                                decoder.name)
        self.decoder = decoder
        self.config = config
```
Now let's look at what each `step` of `BeamSearchDecoder` does.
First, run the wrapped decoder for one step to get its output and new state:
```python
(decoder_output, decoder_state, _, _) = \
    self.decoder.step(time_, inputs, decoder_state)
```
Next, perform one step of beam search, which returns this step's beam search output and state:
```python
bs_output, beam_state = beam_search.beam_search_step(
    time_=time_,
    logits=decoder_output.logits,
    beam_state=beam_state,
    config=self.config)
```
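Here `time_` is the current time step, starting from 0; at step 0 all beams are considered identical. `logits` is a `[B, vocab_size]` tensor holding the logits for the current step (B is the beam width); `beam_state` is the current state, a `BeamSearchState` instance; and `config` holds the related parameters.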
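Why look at only one beam at step 0? A small plain-Python illustration, assuming 2 beams and a 3-word vocabulary: because all rows are identical at step 0, a top-k over the whole flattened grid would pick the same word once per beam, while restricting to beam 0 yields distinct candidates. `top_k_flat` is a hypothetical helper standing in for `tf.nn.top_k` on the flattened scores.

```python
def top_k_flat(scores, k):
    """Top-k (score, flat_index) pairs over a [beam][vocab] grid, like tf.nn.top_k on the flattened scores."""
    flat = [s for row in scores for s in row]
    order = sorted(range(len(flat)), key=lambda i: flat[i], reverse=True)
    return [(flat[i], i) for i in order[:k]]


vocab_size = 3
# At step 0 every beam holds the same (empty) hypothesis, so all rows are identical.
scores_t0 = [[-0.1, -1.0, -2.3], [-0.1, -1.0, -2.3]]

all_beams = [i % vocab_size for _, i in top_k_flat(scores_t0, k=2)]
first_beam = [i % vocab_size for _, i in top_k_flat(scores_t0[:1], k=2)]
print(all_beams)   # [0, 0]: the same word chosen twice
print(first_beam)  # [0, 1]: two distinct candidates
```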
2.1 Inside step
Let's dig into this function:
```python
import tensorflow as tf

# mask_probs, hyp_score, BeamSearchState and BeamSearchStepOutput are defined
# in the same beam_search module.


def beam_search_step(time_, logits, beam_state, config):
    """Performs a single step of beam search decoding.

    Args:
      See the explanation below the code.

    Returns:
      A (BeamSearchStepOutput, BeamSearchState) pair.
    """
    # Current lengths of the predictions
    prediction_lengths = beam_state.lengths
    previously_finished = beam_state.finished

    # Log probabilities of the new tokens, shape [beam_width, vocab_size]
    probs = tf.nn.log_softmax(logits)
    # Mask beams that have already finished so that they do not keep growing
    probs = mask_probs(probs, config.eos_token, previously_finished)
    # Total log probability of each new hypothesis (previous beam + new token)
    total_probs = tf.expand_dims(beam_state.log_probs, 1) + probs

    # Lengths of the continuations: add 1 to every continuation that is
    # neither EOS nor extends a previously finished beam
    lengths_to_add = tf.one_hot([config.eos_token] * config.beam_width,
                                config.vocab_size, 0, 1)
    add_mask = (1 - tf.to_int32(previously_finished))
    lengths_to_add = tf.expand_dims(add_mask, 1) * lengths_to_add
    new_prediction_lengths = tf.expand_dims(prediction_lengths,
                                            1) + lengths_to_add

    # Score of every candidate hypothesis (includes the length penalty)
    scores = hyp_score(
        log_probs=total_probs,
        sequence_lengths=new_prediction_lengths,
        config=config)
    scores_flat = tf.reshape(scores, [-1])
    # At the first time step, only consider the initial beam
    scores_flat = tf.cond(
        tf.convert_to_tensor(time_) > 0, lambda: scores_flat, lambda: scores[0])

    # Pick the next beams with the configured successors function (see below)
    next_beam_scores, word_indices = \
        config.choose_successors_fn(scores_flat, config)
    # next_beam_scores.set_shape([config.beam_width])
    word_indices.set_shape([config.beam_width])

    # Gather the log probabilities, word ids and parent beam ids of the chosen candidates
    total_probs_flat = tf.reshape(total_probs, [-1], name="total_probs_flat")
    next_beam_probs = tf.gather(total_probs_flat, word_indices)
    next_beam_probs.set_shape([config.beam_width])
    next_word_ids = tf.mod(word_indices, config.vocab_size)
    next_beam_ids = tf.div(word_indices, config.vocab_size)

    # A beam is finished if its parent beam was finished or it just emitted EOS
    next_finished = tf.logical_or(
        tf.gather(beam_state.finished, next_beam_ids),
        tf.equal(next_word_ids, config.eos_token))

    # Lengths of the beams for the next step:
    # 1. beams that had already finished do not grow
    # 2. beams whose new word is EOS do not grow
    # 3. all other beams grow by 1
    lengths_to_add = tf.to_int32(tf.not_equal(next_word_ids, config.eos_token))
    lengths_to_add = (1 - tf.to_int32(next_finished)) * lengths_to_add
    next_prediction_len = tf.gather(beam_state.lengths, next_beam_ids)
    next_prediction_len += lengths_to_add

    next_state = BeamSearchState(
        log_probs=next_beam_probs,
        lengths=next_prediction_len,
        finished=next_finished)

    output = BeamSearchStepOutput(
        scores=next_beam_scores,
        predicted_ids=next_word_ids,
        beam_parent_ids=next_beam_ids)

    return output, next_state
```
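First, the inputs. `logits` are the logits for the current time step. `beam_state` (a `BeamSearchState`) contains three fields: `log_probs` (the accumulated log probability of every beam at the current step, i.e. how likely each hypothesis is), `finished` (whether each beam has finished; in the code above a beam is marked finished once it, or its parent, has emitted the end token), and `lengths` (the length of each beam, i.e. how many words it contains so far). `config` just holds the related parameters.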
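Putting the pieces together, here is a plain-Python sketch of one `beam_search_step` for a single sentence, written to mirror the TensorFlow code above: masking finished beams, adding log probabilities, updating lengths, applying the length penalty, taking the top-k over the flattened scores, and splitting each flat index into a parent beam id and a word id. It is an illustrative re-implementation, not the library's code. `BeamState`, `beam_step`, and `alpha` (the length penalty exponent) are hypothetical names, and `-math.inf` is used for masked entries.

```python
import math
from typing import List, NamedTuple


class BeamState(NamedTuple):
    log_probs: List[float]  # accumulated log probability of each beam
    lengths: List[int]      # number of non-EOS words in each beam
    finished: List[bool]    # whether each beam has emitted EOS


def log_softmax(row: List[float]) -> List[float]:
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def beam_step(time_, logits, state, eos, beam_width, alpha=0.0):
    """One beam search step on plain lists; logits has shape [beam_width][vocab]."""
    vocab = len(logits[0])
    total, lengths = [], []
    for b in range(beam_width):
        if state.finished[b]:
            # Masking: a finished beam can only be extended by EOS, at zero cost.
            probs = [0.0 if w == eos else -math.inf for w in range(vocab)]
        else:
            probs = log_softmax(logits[b])
        total.append([state.log_probs[b] + p for p in probs])
        # +1 word unless the beam had already finished or the new word is EOS.
        lengths.append([state.lengths[b] + int(not state.finished[b] and w != eos)
                        for w in range(vocab)])
    # Length-penalized score: total log prob / lp(length).
    scores = [[total[b][w] / ((5 + lengths[b][w]) ** alpha / 6 ** alpha)
               for w in range(vocab)] for b in range(beam_width)]
    # At time 0 all beams are identical, so only beam 0 is considered.
    flat = [s for row in (scores if time_ > 0 else scores[:1]) for s in row]
    top = sorted(range(len(flat)), key=lambda i: flat[i], reverse=True)[:beam_width]
    parent_ids = [i // vocab for i in top]  # like tf.div(word_indices, vocab_size)
    word_ids = [i % vocab for i in top]     # like tf.mod(word_indices, vocab_size)
    new_state = BeamState(
        log_probs=[total[p][w] for p, w in zip(parent_ids, word_ids)],
        lengths=[lengths[p][w] for p, w in zip(parent_ids, word_ids)],
        finished=[state.finished[p] or w == eos for p, w in zip(parent_ids, word_ids)],
    )
    return [flat[i] for i in top], word_ids, parent_ids, new_state


state = BeamState(log_probs=[0.0, 0.0], lengths=[0, 0], finished=[False, False])
_, words, parents, state = beam_step(0, [[2.0, 1.0, 0.0]] * 2, state, eos=2, beam_width=2)
print(words, parents, state.lengths)  # [0, 1] [0, 0] [1, 1]
```

The sketch reads each new length directly from the candidate grid, while the TensorFlow code recomputes it after selection; both give the parent's length plus 1 unless the beam is finished or the new word is EOS.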
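Next, `hyp_score`. This function applies a length penalty, an idea from the 2016 paper describing Google's Neural Machine Translation (GNMT) system. The reasoning is simple: every word contributes a negative log probability, and we want the total score to be as large as possible, so the raw score favors shorter sentences with fewer words. That is clearly not what we want. A length penalty $lp(Y)$ is therefore introduced to normalize for sentence length: the score of a hypothesis $Y$ becomes its total log probability divided by $lp(Y)$, so the (negative) totals of longer hypotheses are shrunk more. The exponent $\alpha$ controls the strength of the penalty ($\alpha = 0$ disables it), and values of $\alpha$ between 0.6 and 0.7 are typical:
$$lp(Y) = \frac{(5+|Y|)^{\alpha}}{(5+1)^{\alpha}}$$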
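A small numeric check of the formula above, as a sketch with hypothetical function names and made-up log-probability values: with $\alpha = 0$ (no penalty) the shorter hypothesis wins simply because it has fewer negative terms, while with $\alpha = 0.7$ a longer hypothesis with a somewhat lower total log probability scores higher.

```python
def length_penalty(length: int, alpha: float) -> float:
    """GNMT length penalty: lp(Y) = (5 + |Y|)^alpha / (5 + 1)^alpha."""
    return (5 + length) ** alpha / (5 + 1) ** alpha


def score(total_log_prob: float, length: int, alpha: float) -> float:
    """Hypothesis score: total log probability divided by the length penalty."""
    return total_log_prob / length_penalty(length, alpha)


# Made-up hypotheses: (total log prob, length) for a short and a longer sentence.
short, long_ = (-2.0, 2), (-3.0, 10)
print(score(*short, alpha=0.0) > score(*long_, alpha=0.0))  # True: no penalty, short wins
print(score(*long_, alpha=0.7) > score(*short, alpha=0.7))  # True: with the penalty, long wins
```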
Looking at how `choose_successors_fn` is defined and at the related code, it is simply `choose_top_k` here, which picks the next beams. Here is that function:
```python
def choose_top_k(scores_flat, config):
    """Chooses the top-k beams as successors.
    """
    next_beam_scores, word_indices = tf.nn.top_k(scores_flat, k=config.beam_width)
    return next_beam_scores, word_indices
```
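2.2 After the step
Next, the decoder state and output are reordered ("shuffled") according to the beam search result: each surviving beam picks up the decoder state and output of its parent beam, selected with `tf.gather` on `beam_parent_ids`. The results are then packaged up as the output.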
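The reordering described above amounts to a row gather along the beam axis, applied to every component of the state. A plain-Python sketch, assuming each state component is a [beam_width, depth] list of rows; `gather` and `reorder_state` are hypothetical helpers mirroring `tf.gather` and `nest.map_structure`.

```python
def gather(rows, indices):
    """Plain-Python tf.gather along axis 0: pick rows by index."""
    return [rows[i] for i in indices]


def reorder_state(state, parent_ids):
    """Apply gather to every leaf of a nested tuple (like nest.map_structure)."""
    if isinstance(state, tuple):
        return tuple(reorder_state(s, parent_ids) for s in state)
    return gather(state, parent_ids)


# An LSTM-like (c, h) state for 3 beams, depth 1.
state = ([[0.1], [0.2], [0.3]], [[1.0], [2.0], [3.0]])
# Suppose beam search kept two children of beam 1 and one child of beam 0.
print(reorder_state(state, [1, 1, 0]))  # ([[0.2], [0.2], [0.1]], [[2.0], [2.0], [1.0]])
```

A parent that spawns several survivors has its state copied into each of them, and a beam with no surviving children simply disappears.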
2.3 The complete step function
```python
def step(self, time_, inputs, state, name=None):
    decoder_state, beam_state = state

    # Call the original decoder
    (decoder_output, decoder_state, _, _) = self.decoder.step(time_, inputs,
                                                              decoder_state)

    # Perform a step of beam search
    bs_output, beam_state = beam_search.beam_search_step(
        time_=time_,
        logits=decoder_output.logits,
        beam_state=beam_state,
        config=self.config)

    # Shuffle everything according to beam search result
    decoder_state = nest.map_structure(
        lambda x: tf.gather(x, bs_output.beam_parent_ids), decoder_state)
    decoder_output = nest.map_structure(
        lambda x: tf.gather(x, bs_output.beam_parent_ids), decoder_output)

    next_state = (decoder_state, beam_state)

    outputs = BeamDecoderOutput(
        logits=tf.zeros([self.config.beam_width, self.config.vocab_size]),
        predicted_ids=bs_output.predicted_ids,
        log_probs=beam_state.log_probs,
        scores=bs_output.scores,
        beam_parent_ids=bs_output.beam_parent_ids,
        original_outputs=decoder_output)

    finished, next_inputs, next_state = self.decoder.helper.next_inputs(
        time=time_,
        outputs=decoder_output,
        state=next_state,
        sample_ids=bs_output.predicted_ids)
    next_inputs.set_shape([self.batch_size, None])

    return (outputs, next_state, next_inputs, finished)
```
3. Summary
Beam search feels a bit like a BFS with a width limit, where the width is the beam size (`beam_width` in the code).
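To make the BFS analogy concrete, here is a generic, textbook-style beam search over a hypothetical toy next-token model (plain Python, not the library's code; `beam_search` and `toy_model` are made-up names). Each level expands every unfinished hypothesis as BFS would, but only the best `beam_size` candidates survive; finished hypotheses stay in the pool unchanged, much like finished beams being extended only by EOS at zero cost. With `beam_size=1` it reduces to greedy search.

```python
import math
from typing import Callable, Dict, List, Tuple

Hyp = Tuple[Tuple[str, ...], float]  # (tokens, total log probability)


def beam_search(next_log_probs: Callable[[Tuple[str, ...]], Dict[str, float]],
                beam_size: int, max_len: int, eos: str = "</s>") -> List[Hyp]:
    beams: List[Hyp] = [((), 0.0)]
    for _ in range(max_len):
        candidates: List[Hyp] = []
        for tokens, logp in beams:
            if tokens and tokens[-1] == eos:
                candidates.append((tokens, logp))  # finished beams are carried over unchanged
                continue
            for tok, tok_logp in next_log_probs(tokens).items():
                candidates.append((tokens + (tok,), logp + tok_logp))
        # A full BFS would keep every candidate; beam search keeps only the best beam_size.
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
    return beams


def toy_model(prefix: Tuple[str, ...]) -> Dict[str, float]:
    """Hypothetical next-token distribution, for illustration only."""
    table = {
        (): {"a": 0.6, "b": 0.4},
        ("a",): {"</s>": 0.5, "c": 0.5},
        ("b",): {"</s>": 0.9, "c": 0.1},
    }
    return {tok: math.log(p) for tok, p in table.get(prefix, {"</s>": 1.0}).items()}


print(beam_search(toy_model, beam_size=1, max_len=3)[0][0])  # ('a', '</s>'), greedy
print(beam_search(toy_model, beam_size=2, max_len=3)[0][0])  # ('b', '</s>')
```

Greedy search commits to "a" (0.6) and ends with probability 0.6 × 0.5 = 0.3, while a beam of 2 keeps "b" alive and finds "b </s>" with 0.4 × 0.9 = 0.36.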
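Reading the code also taught me a lot about implementation details, such as how inference handles beams that finish early, and smaller pieces like the bridge. Very rewarding!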