pytorch实现给公式添加参数

我在使用Bert时,需要自己实现一个公式,查阅了很多博客,不知道怎么添加可训练的矩阵参数,自己写了一个简单的实现,不清楚公式中的v和W是否属于可训练的参数,如有大神路过,烦请解惑,十分感谢。

class IntentAdd(nn.Module):
    def __init__(self, hidden):#hidden=768
        super(IntentAdd, self).__init__()
        self.hidden = hidden
        self.v = nn.Parameter(torch.rand(self.hidden))
        self.W = nn.Parameter(torch.rand(self.hidden))
    def forward(self, snips_output, pooled_output):
        return self.v * torch.tanh(snips_output + self.W * pooled_output)
发布了50 篇原创文章 · 获赞 44 · 访问量 8896

猜你喜欢

转载自blog.csdn.net/tailonh/article/details/105372390