class Attention(nn.Module):
    """Label-wise attention pooling over a sequence.

    Each of ``labels_num`` labels learns its own scoring vector; the module
    attends over the sequence positions and returns one pooled hidden vector
    per label.

    Shapes:
        inputs: [batch, seq_len1, hidden_size]
        masks:  [batch, seq_len1] — nonzero/True for valid positions,
                zero/False for padding.
        output: [batch, labels_num, hidden_size]
    """

    def __init__(self, labels_num, hidden_size):
        super(Attention, self).__init__()
        # One scoring vector per label; bias-free so padding scores stay
        # comparable across labels.
        self.attention = nn.Linear(hidden_size, labels_num, bias=False)
        nn.init.xavier_uniform_(self.attention.weight)

    def forward(self, inputs, masks):
        # [batch, 1, seq_len1] so the mask broadcasts over the label axis.
        masks = torch.unsqueeze(masks, 1)
        # Raw scores: [batch, labels_num, seq_len1].
        # BUGFIX: masked_fill requires a bool mask; the original
        # `1.0 - masks` produced a float tensor (and raises on bool masks).
        # `masks == 0` works for bool, int, and float masks alike.
        scores = self.attention(inputs).transpose(1, 2)
        scores = scores.masked_fill(masks == 0, float("-inf"))
        # Normalize over sequence positions; padded positions get weight 0.
        weights = F.softmax(scores, -1)
        # Weighted sum of inputs: [batch, labels_num, hidden_size].
        return weights @ inputs
# A PyTorch self-attention implementation that maps inputs of shape
# [batch, seq_len1, hidden_dim] to outputs of shape [batch, seq_len2, hidden_dim].
# Adapted from: blog.csdn.net/guotong1988/article/details/103559258