前言
论文地址:https://arxiv.org/abs/1905.02450
代码地址:https://github.com/microsoft/MASS
前人工作&存在问题
预训练+微调可以缓解特定下游任务语料不足的缺陷,如ELMO\GPT\BERT。但是BERT模型是为NLU任务设计的。
也有一些为NLG任务设计的预训练模型,如:
- 利用一个语言模型或者自编码器来预训练encoder和decoder(效果没有BERT好);
- 设计了一个句子重排任务(只为encoder做预训练);
- XLM为encoder decoder单独训练一个BERT(attention模块得不到预训练、不适用于NLG的任务);
本文贡献
本文设计了一个基于mask的,但能够直接预训练整个encoder-decoder架构的预训练模式,在NLG下游任务上得到了更好的结果。
具体方法
模型架构如图1所示,通过被mask掉的连续分段长度k来和标准语言模型(GPT\ELMO)、MASK语言模型(BERT)进行比较:
- 当k=1时:可以理解为decoder所需的信息全都来自于encoder的双向信息。具体来说,decoder唯一的mask位置通过CAN聚合了所有上下文信息,decoder也就可以被理解为是BERT encoder上的softmax分类层。
- 当k=m时:encoder端提供不了任何有用信息,decoder就是一个language model。
- 当1< k >m时:
- encoder端必须要深刻理解整体的语义(like encoder in NMT)
- decoder端必须充分利用encoder的信息,同时也可以发挥language model的优势,即利用之前输入的信息(like decoder in NMT)
- decoder端想要预测的内容,decoder端没有提供完整的上下文,encoder端提供了,又没有直接提供对应的内容,所以decoder既要充分利用encoder的上下文信息,也要部分利用之前的输入信息(1. like NMT;2.和DAE预训练encoder-decoder相比之下的优势(DAE的decoder端要预测的东西,可以直接从encoder端copy过来,而copy机制由于transformer中的残差网络显得更加容易))
另外,具体实现时,k一般取50%。decoder端被mask的token不会作为输入,假如原本输入是【mask1,mask2, token1, token2, … t o k e n n token_n tokenn, mask3, mask4】,那么只需要输入【mask2(用于预测token1),token1,… t o k e n n − 1 token_{n-1} tokenn−1】就可以了,节约了一半的计算。
具体实验
UNMT实验结果如何?
如图3中,方法有:
- 不知
- 双语词典+back-translation
- 不知
- (combined BPE + fasttext)+(DAE+back-translation)
- (combined BPE + MLM\CLM)+(DAE+back-translation)
- (combined BPE + MASS)+back-translation
如图4中,方法有:
- BERT+LM:用BERT来预训练encoder,language-model来预训练decoder,CAN没得训练;
- DAE:直接用DAE来对encoder-decoder做预训练;
如图3、4,MASS,好!
Low-resource NMT结果如何?
如图5所示,在不同的数据量上,预训练过的MASS都比没有经过预训练的baseline效果更好。
超参数k的实验
k=50%的效果最好,k太小,decoder的生成能力弱;k太大,decoder退化成一个standard language model,虽然拥有很强的生成能力,但不会充分利用encoder端的信息。(个人理解)
MASS的两个变种:1. 离散的mask 2. decoder端不进行mask
如图6所示,结果表明:
- 离散的mask下降不多(待探究)
- decoder端不进行mask的下降较大,如果不进行mask,则decoder端退化为language model,不太关注encoder的信息(这也可能是传统NMT训练方式的缺点,待探究)
- 通过图6的Feed可知:DAE用于训练encoder+decoder和MASS的decoder端不进行mask的唯一差别在于,DAE的encoder端做了一些重排、删除、mask!,而MASS的encoder端就做了连续的 mask,也就说明了连续mask的有效性?
心得总结
- encoder-decoder用于NMT时,decoder端的输入也是完整的target句子,这时decoder会不会也存在,不太利用到encoder端信息的问题?,极端点讲,退化成一个language model。所以训练一般的NMT时,可不可以对decoder端进行一些mask?