我在使用Bert时,需要自己实现一个公式,查阅了很多博客,不知道怎么添加可训练的矩阵参数,自己写了一个简单的实现,不清楚公式中的v和W是否属于可训练的参数,如有大神路过,烦请解惑,十分感谢。
class IntentAdd(nn.Module):
def __init__(self, hidden):#hidden=768
super(IntentAdd, self).__init__()
self.hidden = hidden
self.v = nn.Parameter(torch.rand(self.hidden))
self.W = nn.Parameter(torch.rand(self.hidden))
def forward(self, snips_output, pooled_output):
return self.v * torch.tanh(snips_output + self.W * pooled_output)