Understanding Gu2018NonAutoregressiveNM
Motivation
解码模型 | 训练循环个数 | 测试循环个数 | |
---|---|---|---|
Autoregressive | \(p_{\mathcal{A} \mathcal{R}}(Y \mid X ; \theta)=\prod_{t=1}^{T+1} p\left(y_{t} \mid y_{0: t-1}, x_{1: T^{\prime}} ; \theta\right)\) | \(n\) (RNN); \(1\) (Transformer) | \(n\) |
Non-autoregressive | \(p_{\mathcal{N} \mathcal{A}}(Y \mid X ; \theta)=p_{L}\left(T \mid x_{1: T^{\prime}} ; \theta\right) \cdot \prod_{t=1}^{T} p\left(y_{t} \mid x_{1: T^{\prime}} ; \theta\right)\) | \(1\) | \(1\) |
其中循环个数表示的是数据经Decoder操作的次数, 非组回归(Non-Autoregressive)模型可以大幅加速训练时的解码速度.
Model
对Transformer做了如下小幅改动:
- Encoder后接一个Fertility Predictor, 代表Decoder输入中原句的拷贝次数, 作者设定最大拷贝次数为50;
- 在Multi-Head Self-Attention层中, 不同于Transformer中mask了之后看到的句子, 这里不需要任何mask. 但作者mark了自身, 只能看到之前和之后的句子, 并指出这样效果更好;
- Decoder输入没有加位置编码, 而是把Q和K作为位置编码, 用在了Multi-Head Positional Attention中, 原文中作者指出这样做效果更好.
Training
直接训练整个模型是困难的, 因为使用Fertility构造Decoder输入时不可导的, 因此这里作者给出了极大似然概率的下界 (不等式成立时因为\(\log (a+b)\ge \log a+\log b)\), 当\(a,b\ge 1\)), 因此可以将Fertility Predictor和Decoder分开训练. \[ \begin{aligned} \mathcal{L}_{\mathrm{ML}} &=\log p_{\mathcal{N} \mathcal{A}}(Y \mid X ; \theta)=\log \sum_{f_{1: T^{\prime}} \in \mathcal{F}} p_{F}\left(f_{1: T^{\prime}} \mid x_{1: T^{\prime}} ; \theta\right) \cdot p\left(y_{1: T} \mid x_{1: T^{\prime}}, f_{1: T^{\prime}} ; \theta\right) \\ & \geq \underset{f_{1: T^{\prime}} \sim q}{\mathbb{E}}(\underbrace{\sum_{t=1}^{T} \log p\left(y_{t} \mid x_{1}\left\{f_{1}\right\}, . ., x_{T^{\prime}}\left\{f_{T^{\prime}}\right\} ; \theta\right)}_{\text {Translation Loss }}+\underbrace{\sum_{t^{\prime}=1}^{T^{\prime}} \log p_{F}\left(f_{t^{\prime}} \mid x_{1: T^{\prime}} ; \theta\right)}_{\text {Fertility Loss }})+\mathcal{H}(q) \end{aligned} \]
Fertility Predictor训练方式: 找一个外部的Text Aligner (如ab/fast_align), 定义其生成的分布为\(q\), 监督训练.
Decoder训练方式:
- 直接对语料库监督训练, 效果不好, 产生Multimodality Problem;
- 使用Knowledge Distillation.
Multimodality Problem
模型的输出是条件独立的, 作者在原文给了个很形象的例子, 并称之多模态问题 (Multimodality Problem).
Such a decoder is akin to a panel of human translators each asked to provide a single word of a translation independently of the words their colleagues choose.
Techniques
Noisy Parallel Decoding
解决确定输出长度的问题, 思想就是多输出几个不同长度的句子, 从里面挑一个最好的, 步骤如下:
- 对Fertility采样;
- 生成句子并用训练好的Autoregressive打分;
- 选一个分最高的作为输出.
形式上即 \[
\hat{Y}_{\mathrm{NPD}}=G\left(x_{1: T^{\prime}}, \underset{f_{t^{\prime}
\sim p_{F}}}{\operatorname{argmax}} p_{\mathcal{A R}}\left(G\left(x_{1:
T^{\prime}}, f_{1: T^{\prime}} ; \theta\right) \mid X ; \theta\right) ;
\theta\right)
\]
Knowledge Distillation
解决多模态问题, 步骤如下:
- 训练一个Autoregressive模型作为老师;
- 将老师Greedy解码的输出作为标签训练NAT.
Fine-Tuning
为了摆脱外部Text Aligner的依赖, 将Fertility Predictor和Decoder一起用强化学习训练, 生成reward时需要训练好的Autoregressive模型作为老师, 下面定义强化学习的各个要素.
State: Encoder输出;
Action: 原句生成Fertilities;
Reward: 老师和学生输出的相似度, 形式上定义为输出分布的Reverse K-L Divergence中的一项, 即 \[ \mathcal{L}_{\mathrm{RKL}}\left(f_{1: T^{\prime}} ; \theta\right)=\sum_{t=1}^{T} \sum_{y_{t}}\left[\log p_{\mathcal{A} \mathcal{R}}\left(y_{t} \mid \hat{y}_{1: t-1}, x_{1: T^{\prime}}\right) \cdot p_{\mathcal{N} \mathcal{A}}\left(y_{t} \mid x_{1: T^{\prime}}, f_{1: T^{\prime}} ; \theta\right)\right] \]
这个Reward是怎么来的呢? 作者在文中的解释比较模糊:
Such a loss is more favorable towards highly peaked student output distributions than a standard cross-entropy error would be.
以下参考Forward and Reverse KL Divergence给出详细的解释.
K-L Divergence的定义为 \[ D_{K L}(A \| B)=\sum_{i} p_{A}\left(v_{i}\right) \log p_{A}\left(v_{i}\right)-p_{A}\left(v_{i}\right) \log p_{B}\left(v_{i}\right) \] Entropy的定义如下, 最大时分布为均匀分布, 见Why is Entropy maximised when the probability distribution is uniform?. \[ \mathrm{H}(X)=-\sum_{i=1}^{n} \mathrm{P}\left(x_{i}\right) \log _{b} \mathrm{P}\left(x_{i}\right) \] 给定\(p_{\mathcal{AR}}\), 我们的目标是最小化\(p_{\mathcal{AR}}\)和\(p_{\mathcal{NA}}\)间的K-L Divergence, 那么我们有两种写法:
\[ \min D_{KL}(p_{\mathcal{AR}}||p_{\mathcal{NA}})\Longleftrightarrow\max\sum p_{\mathcal{AR}}\log p_{\mathcal{NA}} \] \[ \min D_{KL}(p_{\mathcal{NA}}||p_{\mathcal{AR}})\Longleftrightarrow \max\sum\bigg( p_{\mathcal{NA}}\log p_{\mathcal{AR}}- p_{\mathcal{NA}}\log p_{\mathcal{NA}}\bigg)\Longleftrightarrow \max\bigg(\sum p_{\mathcal{NA}}\log p_{\mathcal{AR}}+H(p_{\mathcal{NA}})\bigg) \]
对于相等的\(D_{KL}(p_{\mathcal{NA}}||p_{\mathcal{AR}})\), \(\sum\limits p_{\mathcal{NA}}\log p_{\mathcal{AR}}\)越大, \(H(p_{\mathcal{NA}})\)就越小, 而\(p_{\mathcal{NA}}\)不均匀是我们希望看到的情况, 因此最大化\(\sum\limits p_{\mathcal{NA}}\log p_{\mathcal{AR}}\)就够了. 本文使用REINFORCE算法进行训练, 损失函数如下 (这里是最大化问题) \[ \mathcal{L}_{\mathrm{FT}}=\lambda(\underbrace{\underset{f_{1: T^{\prime} \sim p_{F}}}{\mathbb{E}}\left(\mathcal{L}_{\mathrm{RKL}}\left(f_{1: T^{\prime}}\right)-\mathcal{L}_{\mathrm{RKL}}\left(\bar{f}_{1: T^{\prime}}\right)\right)}_{\mathcal{L}_{\mathrm{RL}}}+\underbrace{\underset{f_{1: T^{\prime} \sim q}}{\mathbb{E}}\left(\mathcal{L}_{\mathrm{RKL}}\left(f_{1: T^{\prime}}\right)\right)}_{\mathcal{L}_{\mathrm{BP}}})+(1-\lambda) \mathcal{L}_{\mathrm{KD}} \]
这里对其各项作解释如下:
- \(\mathcal{L}_{\mathrm{RL}}\): REINFORCE算法的更新式, 目的是训练Fertility Predictor;
- \(\mathcal{L}_{\mathrm{BP}}\): 对Text Aligner损失, 目的是让Fertility Predictor保持Text Aligner生成Fertilities的翻译能力;
- \(\mathcal{L}_{\mathrm{KD}}\): Knowledge Distillation损失, 即算损失时把标签替换为老师模型的输出 (?应该是这个意思, 原文没有给出具体表达式)
Experiments
以二倍的速度达到了Autoregressive模型近似的效果.
References
[DLHLP 2020] Non-Autoregressive Sequence Generation (由助教莊永松同學講授)
[TA 補充課] Network Compression (1/2): Knowledge Distillation (由助教劉俊緯同學講授)