摘要:GAN采用判别模型引导生成模型的训练在连续型数据上已经产生了很好的效果,但是有两个limitations,第一,当目标是离散数据时,如文本,不可能文本+1产生梯度信息引导生成器的生成;第二,判别模型只能对完整的序列产生判别信息,对于非完整序列,它并不知道当前的判别结果和未来完整序列的判别结果是否相同。SeqGAN可以解决这两个问题。采用强化学习的reward思想,实行梯度策略更新解决生成器的微分问题,即解决了第一个问题,采用Monte Carlo search将不完整的序列补充完整解决第二个问题。
SeqGAN:
给定真实序列数据集,训练产生序列,,是词汇表。在第步,状态是当前生成的序列, 行为是下一个选择的token,因此策略模型是随机的。此外,训练给提供引导,代表序列是否为真实序列的概率。正样本由真实数据采样得到,负样本为生成器生成的数据,采样正样本和负样本一起训练判别模型。同时,通过策略梯度和从判别模型得到的the expected end reward的基础上进行MC搜索更新生成模型,通过似然函数评估reward。
SeqGAN具体细节如下:
生成器模型(策略)目的是从开始状态产生一个序列最大化the expected end reward。
式中,表示完整序列的reward,这个reward是来自于判别模型,是一个序列的action-value function。即,这个期望累积reward从初始状态开始,采取行动,遵循策略。因此生成器的目标就是从初始状态开始,产生一个序列,让判别器判别它是真实的,即reward越大越好。
下一个问题是如何计算,该篇论文采用增强学习的思路,并且考虑通过判别器评估概率,作为reward。
然而判别器只对一个完整的序列提供reward值,因此采用MC搜索with a roll-out策略来采样未知的个tokens,表示N次MC搜索为
即产生N个序列。和通过roll-out策略和当前状态采样。在该篇论文中与生成器相同。为了减少方差并且得到更精切的action-value值的评估,从当前状态搭到整个序列完成运行roll-out策略N次得到输出样本的一个批次,因此,我们有
采用作为reward的优点是可以动态更新,进一步迭代地提高生成器模型。一旦我们有更多的真实序列,我们会重新训练判别器模型:
当我们有了判别器模型,我们就准备更新生成器。当action确定,状态转移是确定的,即,下一个状态,行动,另外其他的状态。此外当前的reward是,重写action value如下:
对于初始状态,
目标函数的梯度如下:
上面的形式是确定性的状态转移,即是固定的,以下采用似然比例建立一个无偏的估计。
期望可以通过采样方法近似,因此更新生成器的参数:
整个算法如下: