Understanding Gu2018NonAutoregressiveNM

Gu, J., Bradbury, J., Xiong, C., Li, V.O., & Socher, R. (2018). Non-Autoregressive Neural Machine Translation. ArXiv, abs/1711.02281.阅读笔记.

Motivation

解码模型 训练循环个数 测试循环个数
Autoregressive \(p_{\mathcal{A} \mathcal{R}}(Y \mid X ; \theta)=\prod_{t=1}^{T+1} p\left(y_{t} \mid y_{0: t-1}, x_{1: T^{\prime}} ; \theta\right)\) \(n\) (RNN); \(1\) (Transformer) \(n\)
Non-autoregressive \(p_{\mathcal{N} \mathcal{A}}(Y \mid X ; \theta)=p_{L}\left(T \mid x_{1: T^{\prime}} ; \theta\right) \cdot \prod_{t=1}^{T} p\left(y_{t} \mid x_{1: T^{\prime}} ; \theta\right)\) \(1\) \(1\)

其中循环个数表示的是数据经Decoder操作的次数, 非组回归(Non-Autoregressive)模型可以大幅加速训练时的解码速度.

Model

对Transformer做了如下小幅改动:

  1. Encoder后接一个Fertility Predictor, 代表Decoder输入中原句的拷贝次数, 作者设定最大拷贝次数为50;
  2. 在Multi-Head Self-Attention层中, 不同于Transformer中mask了之后看到的句子, 这里不需要任何mask. 但作者mark了自身, 只能看到之前和之后的句子, 并指出这样效果更好;
  3. Decoder输入没有加位置编码, 而是把Q和K作为位置编码, 用在了Multi-Head Positional Attention中, 原文中作者指出这样做效果更好.

Training

  • 直接训练整个模型是困难的, 因为使用Fertility构造Decoder输入时不可导的, 因此这里作者给出了极大似然概率的下界 (不等式成立时因为\(\log (a+b)\ge \log a+\log b)\), 当\(a,b\ge 1\)), 因此可以将Fertility Predictor和Decoder分开训练. \[ \begin{aligned} \mathcal{L}_{\mathrm{ML}} &=\log p_{\mathcal{N} \mathcal{A}}(Y \mid X ; \theta)=\log \sum_{f_{1: T^{\prime}} \in \mathcal{F}} p_{F}\left(f_{1: T^{\prime}} \mid x_{1: T^{\prime}} ; \theta\right) \cdot p\left(y_{1: T} \mid x_{1: T^{\prime}}, f_{1: T^{\prime}} ; \theta\right) \\ & \geq \underset{f_{1: T^{\prime}} \sim q}{\mathbb{E}}(\underbrace{\sum_{t=1}^{T} \log p\left(y_{t} \mid x_{1}\left\{f_{1}\right\}, . ., x_{T^{\prime}}\left\{f_{T^{\prime}}\right\} ; \theta\right)}_{\text {Translation Loss }}+\underbrace{\sum_{t^{\prime}=1}^{T^{\prime}} \log p_{F}\left(f_{t^{\prime}} \mid x_{1: T^{\prime}} ; \theta\right)}_{\text {Fertility Loss }})+\mathcal{H}(q) \end{aligned} \]

  • Fertility Predictor训练方式: 找一个外部的Text Aligner (如ab/fast_align), 定义其生成的分布为\(q\), 监督训练.

  • Decoder训练方式:

    • 直接对语料库监督训练, 效果不好, 产生Multimodality Problem;
    • 使用Knowledge Distillation.

Multimodality Problem

模型的输出是条件独立的, 作者在原文给了个很形象的例子, 并称之多模态问题 (Multimodality Problem).

Such a decoder is akin to a panel of human translators each asked to provide a single word of a translation independently of the words their colleagues choose.

Techniques

Noisy Parallel Decoding

解决确定输出长度的问题, 思想就是多输出几个不同长度的句子, 从里面挑一个最好的, 步骤如下:

  1. 对Fertility采样;
  2. 生成句子并用训练好的Autoregressive打分;
  3. 选一个分最高的作为输出.

形式上即 \[ \hat{Y}_{\mathrm{NPD}}=G\left(x_{1: T^{\prime}}, \underset{f_{t^{\prime} \sim p_{F}}}{\operatorname{argmax}} p_{\mathcal{A R}}\left(G\left(x_{1: T^{\prime}}, f_{1: T^{\prime}} ; \theta\right) \mid X ; \theta\right) ; \theta\right) \]

Knowledge Distillation

解决多模态问题, 步骤如下:

  1. 训练一个Autoregressive模型作为老师;
  2. 将老师Greedy解码的输出作为标签训练NAT.

Fine-Tuning

为了摆脱外部Text Aligner的依赖, 将Fertility Predictor和Decoder一起用强化学习训练, 生成reward时需要训练好的Autoregressive模型作为老师, 下面定义强化学习的各个要素.

  • State: Encoder输出;

  • Action: 原句生成Fertilities;

  • Reward: 老师和学生输出的相似度, 形式上定义为输出分布的Reverse K-L Divergence中的一项, 即 \[ \mathcal{L}_{\mathrm{RKL}}\left(f_{1: T^{\prime}} ; \theta\right)=\sum_{t=1}^{T} \sum_{y_{t}}\left[\log p_{\mathcal{A} \mathcal{R}}\left(y_{t} \mid \hat{y}_{1: t-1}, x_{1: T^{\prime}}\right) \cdot p_{\mathcal{N} \mathcal{A}}\left(y_{t} \mid x_{1: T^{\prime}}, f_{1: T^{\prime}} ; \theta\right)\right] \]

这个Reward是怎么来的呢? 作者在文中的解释比较模糊:

Such a loss is more favorable towards highly peaked student output distributions than a standard cross-entropy error would be.

以下参考Forward and Reverse KL Divergence给出详细的解释.

K-L Divergence的定义为 \[ D_{K L}(A \| B)=\sum_{i} p_{A}\left(v_{i}\right) \log p_{A}\left(v_{i}\right)-p_{A}\left(v_{i}\right) \log p_{B}\left(v_{i}\right) \] Entropy的定义如下, 最大时分布为均匀分布, 见Why is Entropy maximised when the probability distribution is uniform?. \[ \mathrm{H}(X)=-\sum_{i=1}^{n} \mathrm{P}\left(x_{i}\right) \log _{b} \mathrm{P}\left(x_{i}\right) \] 给定\(p_{\mathcal{AR}}\), 我们的目标是最小化\(p_{\mathcal{AR}}\)\(p_{\mathcal{NA}}\)间的K-L Divergence, 那么我们有两种写法:

\[ \min D_{KL}(p_{\mathcal{AR}}||p_{\mathcal{NA}})\Longleftrightarrow\max\sum p_{\mathcal{AR}}\log p_{\mathcal{NA}} \] \[ \min D_{KL}(p_{\mathcal{NA}}||p_{\mathcal{AR}})\Longleftrightarrow \max\sum\bigg( p_{\mathcal{NA}}\log p_{\mathcal{AR}}- p_{\mathcal{NA}}\log p_{\mathcal{NA}}\bigg)\Longleftrightarrow \max\bigg(\sum p_{\mathcal{NA}}\log p_{\mathcal{AR}}+H(p_{\mathcal{NA}})\bigg) \]

对于相等的\(D_{KL}(p_{\mathcal{NA}}||p_{\mathcal{AR}})\), \(\sum\limits p_{\mathcal{NA}}\log p_{\mathcal{AR}}\)越大, \(H(p_{\mathcal{NA}})\)就越小, 而\(p_{\mathcal{NA}}\)不均匀是我们希望看到的情况, 因此最大化\(\sum\limits p_{\mathcal{NA}}\log p_{\mathcal{AR}}\)就够了. 本文使用REINFORCE算法进行训练, 损失函数如下 (这里是最大化问题) \[ \mathcal{L}_{\mathrm{FT}}=\lambda(\underbrace{\underset{f_{1: T^{\prime} \sim p_{F}}}{\mathbb{E}}\left(\mathcal{L}_{\mathrm{RKL}}\left(f_{1: T^{\prime}}\right)-\mathcal{L}_{\mathrm{RKL}}\left(\bar{f}_{1: T^{\prime}}\right)\right)}_{\mathcal{L}_{\mathrm{RL}}}+\underbrace{\underset{f_{1: T^{\prime} \sim q}}{\mathbb{E}}\left(\mathcal{L}_{\mathrm{RKL}}\left(f_{1: T^{\prime}}\right)\right)}_{\mathcal{L}_{\mathrm{BP}}})+(1-\lambda) \mathcal{L}_{\mathrm{KD}} \]

这里对其各项作解释如下:

  1. \(\mathcal{L}_{\mathrm{RL}}\): REINFORCE算法的更新式, 目的是训练Fertility Predictor;
  2. \(\mathcal{L}_{\mathrm{BP}}\): 对Text Aligner损失, 目的是让Fertility Predictor保持Text Aligner生成Fertilities的翻译能力;
  3. \(\mathcal{L}_{\mathrm{KD}}\): Knowledge Distillation损失, 即算损失时把标签替换为老师模型的输出 (?应该是这个意思, 原文没有给出具体表达式)

Experiments

以二倍的速度达到了Autoregressive模型近似的效果.

References

  • [DLHLP 2020] Non-Autoregressive Sequence Generation (由助教莊永松同學講授)

  • [TA 補充課] Network Compression (1/2): Knowledge Distillation (由助教劉俊緯同學講授)