Seq2Seq中的Exposure Bias现象的原因以及解决办法

参考资料

本文是下列资料的总结:

[1] 李宏毅视频 59:36 开始
[2] Seq2Seq中Exposure Bias现象的浅析与对策
[3] Bridging the Gap between Training and Inference for Neural Machine Translation(2019ACL最佳长论文)
[4] Self-critical Sequence Training

原因

Seq2Seq模型会遇到常说的Exposure Bias现象,原因是在训练阶段和预测阶段会遇到mismatch。训练阶段使用的是Teacher Forcing,也就是decoder在某时刻的输入是上一时刻的ground truth(真实标签)。然而在预测阶段只能使用上一时刻decoder的输出来作为这一时刻的输入,从而导致mismatch。

下图来自资料 [1] 李宏毅老师的视频:

在这里插入图片描述

在decoder的第二个时间步输入了模型在第一个时间步的输出B(预测错误),而不是在训练阶段能够拿到的真实标签A,下一个时间步就会到达没有经过充分训练或甚至完全没有探索过的结点,也就是decoder在训练阶段在第二个时间步只拟合了条件分布 P ( Y 2 X , Y 1 = A ) P(Y_2|X,Y_1=A) 而没有拟合好条件分布 P ( Y 2 X , Y 1 = B ) P(Y_2|X,Y_1=B) 。具体的例子可以见参考资料[2] 的“简单例子”一节。

解决办法

Scheduled Sampling

下图来自资料[1],73分钟附近。

在这里插入图片描述
也就是在第二个时间步开始,输入有一定的概率 p p 使用的是真实标签reference,有 1 p 1-p 的概率使用的是模型在上一个时间步的输出。而概率 p p 随着训练的进行应该逐渐衰减至0,最后就是完全使用模型的输出作为输入,这样就与预测阶段匹配了。

Sentence Level Oracle Word + Gumbel Noise

在参考资料 [3] 中,将采样自模型输出而作为下一个时间步的输入的词称为 oracle word.

他们提出的方法与 Scheduled Sampling 的整体思路一致,只不过对于 oracle word 的选择多了一些设计。

The oracle word should be a word similar to the ground truth or a synonym. Using different strategies will pro-duce a different oracle word.

论文作者认为,oracle word应该与ground truth词是同义词或近义词,然后论文给出了两种得到 oracle word 的方案:

1、Word-Level Oracle Word,这个就是 Scheduled Sampling 使用的方案,实际上等价于下一种方案使用beam-with=1.

2、Sentence-Level Oracle Word,使用 beam-width=k 的 Beam-Search 先得到 k 个候选 decoder 输出,然后根据所关注的指标(例如BLEU分数)来选出 k 个句子中分数最高的句子,将它的单词作为每一步的 oracle word.

注意:使用 Beam-Search 得到的 decoder 输出不一定和 Ground Truth 句子 y y^* 等长,所以需要对 Beam-Search 过程做一些修正:如果某一步的最高概率的词是结束符然而此时长度还不够 y | y^*| ,就选概率第二高的词;如果某一步产生完字符后长度就到达 y | y^*| 了,然而这一步的概率最高词不是结束符,就强制选择结束符

除了提出新的 oracle word 的选择方案,作者还对每一步采样 oracle word 的过程使用了 Gumbel-Softmax 技巧,从而引入了 Gumbel Noise,相当于一种正则化。

注:
1、原文指出Gumbel noise is “treated as a form of regular-ization”,“Gumbel-Max provides a efficient and robustway to sample from a categorical distribution”.
2、Gumbel Softmax是对Gumbel Max的近似,对于 Gumbel Max 的理解可以见 这篇博客,我自己也写过一个Demo来演示Gumbel Max的作用,地址: 戳这里

对抗训练

参考资料 [2] 作者认为,其实前面所述方案的原理在于给训练阶段引入了扰动,让模型在有扰动的情况下依然可以预测正确。所以作者提出了两种带来扰动的方案:

1、启发式的随机替换。50%的概率不做改变;50%的概率把输入序列中30%的词替换掉,替换对象为原目标序列的任意一个词。

2、梯度惩罚

注:至于为什么梯度惩罚等价于对抗训练,可以参考该作者的另一篇博客:对抗训练浅谈:意义、方法和思考(附Keras实现)

作者通过实验说明了这两种方法都有一定效果。

基于强化学习直接优化BLEU

见参考资料 [4],主要工作有MIXER 及其改进。本人对这部分没有深入了解,以后有需要再进一步学习。

发布了67 篇原创文章 · 获赞 27 · 访问量 7万+

猜你喜欢

转载自blog.csdn.net/xpy870663266/article/details/104827790