A PyTorch implementation of self-attention that takes an input of shape [batch, seq_len1, hidden_dim] and produces an output of shape [batch, seq_len2, hidden_dim]

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """
    inputs is [batch, seq_len1, hidden_dim]
    labels_num is seq_len2
    """
    def __init__(self, labels_num, hidden_size):
        super(Attention, self).__init__()
        self.attention = nn.Linear(hidden_size, labels_num, bias=False)
        nn.init.xavier_uniform_(self.attention.weight)

    def forward(self, inputs, masks):
        masks = torch.unsqueeze(masks, 1)  # [batch, 1, seq_len1]
        # Project each position onto labels_num scores, then mask out padded positions.
        attention = self.attention(inputs).transpose(1, 2)  # [batch, labels_num, seq_len1]
        attention = attention.masked_fill(masks == 0, float('-inf'))
        attention = F.softmax(attention, -1)
        return attention @ inputs  # [batch, labels_num, hidden_size]
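
For reference, a minimal usage sketch. The tensor names and the sizes below (batch=2, seq_len1=5, hidden_dim=8, labels_num=3) are illustrative values, not from the original post:

# Illustrative example with assumed sizes: batch=2, seq_len1=5, hidden_dim=8, labels_num=3
attn = Attention(labels_num=3, hidden_size=8)
x = torch.randn(2, 5, 8)                 # [batch, seq_len1, hidden_dim]
masks = torch.tensor([[1, 1, 1, 0, 0],
                      [1, 1, 1, 1, 1]])  # 1 = real token, 0 = padding
out = attn(x, masks)
print(out.shape)                         # torch.Size([2, 3, 8]), i.e. [batch, labels_num, hidden_size]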


Reposted from blog.csdn.net/guotong1988/article/details/103559258