attention_masked
和 attention_weights
是在注意力机制中计算得到的两个重要结果。
- 其中,
attention_weights
表示输入序列中各个时间步对于当前时刻(解码器输出序列中的某一时间步)的权重。具体来说,对于给定的解码器时间步 t 和每个编码器时间步 s,我们都可以计算出一个对应的attention_weights[t,s]
值,表示解码器时间步 t 应该对编码器时间步 s 进行多大程度的关注,以获得更准确的输出。- 而
attention_masked
则是通过将输入序列的加权和与输出序列进行拼接,再通过全连接层进行变换得到的注意力加权后的输出结果。其实际作用是将编码器的信息传递给解码器,同时将解码器输出与编码器信息结合起来生成更准确的结果。
具体来说,attention_masked
的计算方式为:
attention_masked = tanh(Concatenate(decoder_output, attention_context)) * decoder_mask
其中:
decoder_output
是解码器输出序列,attention_context
是输入序列对于当前时刻的加权和,decoder_mask
是解码器的 padding mask。
因此,attention_masked
是将 decoder_output
和 attention_context
进行拼接和变换后得到的加权向量,体现了输入序列对输出序列的影响。
- 因此,
attention_weights
和attention_masked
之间的关系是attention_weights
用于计算输入序列中各个时间步对于当前时刻的权重,- 而
attention_masked
则根据这些权重,将输入序列进行加权求和,并与输出序列进行拼接,最终得到注意力机制加权后的输出结果。
代码示例+ 解释
def forward(self, padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states): # pylint: disable=arguments-differ
"""
Performs the forward pass.
:param padded_seqs: A tensor with the output sequences (batch, seq_d, dim).
:param seq_lengths: A list with the length of each output sequence.
:param encoder_padded_seqs: A tensor with the encoded input scaffold sequences (batch, seq_e, dim).
:param hidden_states: The hidden states from the encoder.
:return : Three tensors: The output logits, the hidden states of the decoder and the attention weights.
"""
padded_encoded_seqs = self._embedding(padded_seqs)
packed_encoded_seqs = tnnur.pack_padded_sequence(padded_encoded_seqs, seq_lengths, batch_first=True, enforce_sorted=False)
packed_encoded_seqs, hidden_states = self._rnn(packed_encoded_seqs, hidden_states)
padded_encoded_seqs, _ = tnnur.pad_packed_sequence(packed_encoded_seqs, batch_first=True) # (batch, seq, dim)
mask = (padded_encoded_seqs[:, :, 0] != 0).unsqueeze(dim=-1).type(torch.float)
attn_padded_encoded_seqs, attention_weights = self._attention(padded_encoded_seqs, encoder_padded_seqs, mask)
logits = self._linear(attn_padded_encoded_seqs)*mask # (batch, seq, voc_size)
return logits, hidden_states, attention_weights
其中 attn_padded_encoded_seqs == attention_masked
mask:
mask
是一个张量,用于在输入序列的填充位置处将输出置零以避免计算填充标记对损失函数的影响。它的形状为 (batch, seq_d, 1)
。
具体来说,mask
初始时被设置为一个布尔值张量,其元素值为 True
或 False
,表示输入序列中是否包含填充标记。然后,将这个布尔值张量转换为浮点数张量并进行扩展,以匹配加权嵌入序列的形状。最后,通过将相应位置的加权嵌入元素与 mask
相乘,将所有填充标记位置的输出置为零。
logits:
logits = self._linear(attn_padded_encoded_seqs)*mask
表示对注意力加权嵌入向量形成的张量 attn_padded_encoded_seqs
进行线性变换,并将一些元素(在输入序列中对应填充标记位置的元素)置零,得到一个概率分布,表示输出序列中每个词汇表单词出现的可能性。
具体来说,self._linear
是一个全连接层(线性变换),将输入的注意力加权嵌入向量的形状 (batch, seq_d, hidden_size)
转换为输出 logits 的形状 (batch, seq_d, voc_size)
,其中 voc_size
是词汇表的大小。然后,通过将 mask
与线性变换的结果相乘,所有填充标记位置的 logits 值将被置零,不再参与计算生成概率分布的过程。
最终输出的 logits
张量是一个三维张量,表示输出序列中每个位置上所有可能单词的概率,可以通过传递给 softmax 函数来获得归一化的概率分布。