The slow, sequential nature of autoregressive (AR) language models has driven the adoption of parallel decoding methods. However, these non-AR models often sacrifice generation quality because they struggle to model the complex joint distribution of token sequences. To narrow this performance gap, we introduce Gumbel Distillation, a novel distillation technique that enables parallel decoders to learn this distribution effectively. Our method leverages the Gumbel-Max trick to create a deterministic mapping from a latent Gumbel noise space to the output tokens of a high-performing AR teacher. As a model-agnostic technique, Gumbel Distillation integrates seamlessly with diverse parallel decoding architectures, including MDLM and BD3-LM. Experiments on LM1B and OpenWebText show that Gumbel Distillation substantially improves the generation quality of parallel language models, achieving a 30.0% improvement in MAUVE score and a 10.5% reduction in generative perplexity over an MDLM baseline trained on OpenWebText. Code is available at https://github.com/hxixixh/gumbel-distill.
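The deterministic noise-to-token mapping rests on the Gumbel-Max trick: if g is i.i.d. standard Gumbel noise, then argmax(logits + g) is an exact sample from softmax(logits), and once g is fixed the choice is fully determined. The sketch below applies this position by position to an AR teacher, so a fixed noise sequence always yields the same token sequence. It is a minimal illustration under stated assumptions, not the authors' implementation: `toy_teacher_logits` is a hypothetical stand-in for a trained AR model, and all function names are invented for this example.

```python
import math
import random

def sample_gumbel(vocab_size, rng):
    """Draw standard Gumbel(0, 1) noise via the inverse CDF g = -log(-log(u))."""
    # Keep u strictly inside (0, 1) so both logarithms stay finite.
    return [-math.log(-math.log(rng.uniform(1e-12, 1.0 - 1e-12)))
            for _ in range(vocab_size)]

def toy_teacher_logits(prefix, vocab_size):
    """Hypothetical stand-in for an AR teacher: logits depend only on the prefix."""
    shift = sum(prefix) + len(prefix)
    return [2.0 * math.sin(0.7 * v + shift) for v in range(vocab_size)]

def gumbel_max_step(logits, noise):
    """Gumbel-Max trick: argmax(logits + g) is an exact sample from softmax(logits)."""
    return max(range(len(logits)), key=lambda v: logits[v] + noise[v])

def decode_from_noise(noise_seq, vocab_size, teacher=toy_teacher_logits):
    """Map one noise vector per position to teacher tokens, left to right.

    The only randomness is in noise_seq, so the map noise -> tokens is deterministic.
    """
    tokens = []
    for g_t in noise_seq:
        logits = teacher(tokens, vocab_size)  # teacher conditions on its own prefix
        tokens.append(gumbel_max_step(logits, g_t))
    return tokens

def softmax(logits):
    """Numerically stable softmax, used only to check the trick empirically."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

if __name__ == "__main__":
    rng = random.Random(0)
    vocab_size, seq_len = 5, 8
    noise_seq = [sample_gumbel(vocab_size, rng) for _ in range(seq_len)]
    # Decoding the same noise twice gives the same sequence.
    print(decode_from_noise(noise_seq, vocab_size))
    print(decode_from_noise(noise_seq, vocab_size))
```

Each pair (`noise_seq`, decoded tokens) is one point of the noise-to-sequence mapping that a parallel student could, in principle, be trained to reproduce. How the student actually consumes the noise is not specified in this abstract and is left out of the sketch.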