Introduction to the Transformer
Encoder Layer
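An encoder layer consists of two sublayers: a multi-head self-attention sublayer followed by a position-wise feed-forward sublayer, each wrapped in a residual connection with layer normalization (the SublayerConnection wrapper in the code below).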
import copy

import torch
import torch.nn as nn


class EncoderLayer(nn.Module):
    def __init__(self, size, self_attn, feed_forward, dropout):
        """
        size: the embedding dimension (d_model)
        self_attn: an instantiated multi-head self-attention sublayer
                   (self-attention: query, key, and value all come from x)
        feed_forward: an instantiated position-wise feed-forward sublayer
        dropout: dropout rate for the sublayer connections
        """
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # one residual + norm wrapper per sublayer
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        # first sublayer: multi-head self-attention (x serves as Q, K, and V)
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        # second sublayer: position-wise feed-forward network
        x = self.sublayer[1](x, self.feed_forward)
        return x
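EncoderLayer relies on the clones helper and the SublayerConnection wrapper, which are assumed to have been defined in an earlier section of this series. For reference, a minimal sketch in the style of the Annotated Transformer might look like this (LayerNorm is sketched below, in the Encoder section):

def clones(module, N):
    """Produce N identical copies of a module, stored in a ModuleList."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class SublayerConnection(nn.Module):
    """A residual connection around a sublayer, with pre-normalization and dropout."""
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # normalize, apply the sublayer, apply dropout, then add the residual
        return x + self.dropout(sublayer(self.norm(x)))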
size = 512          # embedding dimension, must equal d_model
head = 8            # number of attention heads
d_model = 512
d_ff = 2048         # inner dimension of the feed-forward sublayer
x = out_pe          # output of the positional-encoding step from the previous section
dropout = 0.2
# MultiHeadedAttention and PositionwiseFeedForward were defined in earlier sections
self_attn = MultiHeadedAttention(head, d_model, dropout)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
mask = torch.zeros(8, 4, 4)   # placeholder attention mask for this demo
el = EncoderLayer(size, self_attn, ff, dropout)
out_el = el(x, mask)
print(out_el)
print(out_el.shape)
tensor([[[  0.7068,   0.2008,   5.9876,  ..., -21.2479,  35.0978,  39.1085],
         [  0.4000,  23.3858, -10.8174,  ...,  11.0451,  31.5901, -22.6388],
         [-16.9983,  -4.9498, -11.5392,  ...,   0.2313, -26.8330,  10.9285],
         [  3.7644,  50.2599,  45.0382,  ...,   6.7440,  14.0002,   0.1603]],

        [[  0.1804,  -0.2181,   0.2617,  ...,  -8.6161,   0.3426,  12.8423],
         [-15.5720,   8.4282,   5.0564,  ...,  -4.2342,  13.8840,  14.9587],
         [ -7.3524,  -7.0248, -35.8694,  ..., -10.9833,   0.9038, -13.7717],
         [ -0.2831,  37.5377,  16.6007,  ...,   0.6348,  -1.0165, -24.8043]]],
       grad_fn=<AddBackward0>)
torch.Size([2, 4, 512])
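Note that the encoder layer preserves the input shape: the output is still [batch_size, sequence_length, d_model] = [2, 4, 512].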
Encoder
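The full encoder simply stacks N identical encoder layers and applies one final layer normalization to the output of the last layer.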
class Encoder(nn.Module):
    def __init__(self, layer, N):
        """
        layer: an encoder-layer instance (EncoderLayer)
        N: number of encoder layers to stack
        """
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        # pass the input through each layer in turn, then normalize the result
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
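Encoder also references LayerNorm, again assumed to be defined earlier in the series. A minimal sketch consistent with the usage above might be:

class LayerNorm(nn.Module):
    """Layer normalization over the last (feature) dimension, with learnable gain and bias."""
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # gain, initialized to 1
        self.b_2 = nn.Parameter(torch.zeros(features))  # bias, initialized to 0
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2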
size = 512
head = 8
d_model = 512
d_ff = 2048
dropout = 0.2
c = copy.deepcopy
attn = MultiHeadedAttention(head, d_model, dropout)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
# deep-copy the sublayers so every encoder layer gets its own parameters
layer = EncoderLayer(size, c(attn), c(ff), dropout)
N = 8                          # stack 8 encoder layers
mask = torch.zeros(8, 4, 4)
en = Encoder(layer, N)
out_en = en(x, mask)           # x is still out_pe from above
print(out_en)
print(out_en.shape)
tensor([[[-0.7634,  1.3652,  0.6623,  ..., -0.8390, -0.1195, -0.1396],
         [ 0.5829,  0.4156,  0.1490,  ...,  0.7758,  0.0204, -1.1847],
         [-0.3356, -1.8417,  0.1236,  ...,  1.4707, -1.5049, -0.4879],
         [-0.1290, -0.9651,  0.1690,  ..., -0.2859, -2.1153, -1.0485]],

        [[-1.4313, -0.5716, -0.2115,  ..., -1.0504, -0.7756,  0.5585],
         [-0.3245, -1.4687,  1.2963,  ..., -0.7141,  0.1782,  0.1674],
         [-0.0669, -0.1363,  0.0956,  ..., -0.2329, -0.2634, -2.0688],
         [-0.0764,  0.7159,  0.0457,  ...,  0.2907,  0.1509,  0.0606]]],
       grad_fn=<AddBackward0>)
torch.Size([2, 4, 512])
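The shape is unchanged all the way through the stack: [batch_size, sequence_length, d_model]. Also notice that the values of out_en are roughly standardized (mostly within a few units of zero) compared with out_el, because the Encoder applies a final LayerNorm after the last layer.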