In autoregressive language models, each token is sampled conditioned on all past tokens, so the overall string is a sample from the true joint distribution represented by the model. In contrast, masked diffusion language models generate text by unmasking tokens out of order and potentially in parallel. Sampling the overall string from the correct underlying joint distribution would likewise require unmasking exactly one token per full-model forward pass. The more tokens are unmasked in parallel, the further the resulting string drifts from the true joint distribution; this shows up as a drop in accuracy (in exchange for a gain in speed). In this paper we devise a way to {\em approximately} sample multiple tokens from the joint distribution in a single full-model forward pass; we do so by developing a new lightweight single-layer ``sampler'' on top of an existing large diffusion LM. One forward pass of the full model can then be followed by multiple forward passes of only this sampler layer, yielding multiple unmasked tokens. Our sampler is trained to mimic exact joint sampling from the (frozen) full model. We show the effectiveness of our approximate joint sampling on both pretrained-only (Dream-7B-Base, LLaDA-7B-Base) and instruction-tuned (Dream-7B-Instruct, Dream-7B-Coder) models, on language modeling and math \& coding tasks. When four tokens are unmasked per full-model denoising step, our sampling algorithm achieves a MAUVE score of 0.87 with respect to the true joint distribution, versus 0.31 for the marginal-sampling baseline.
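To make the generation loop concrete, the following is a minimal toy sketch of the idea described above: one expensive forward pass of the full (frozen) diffusion LM, followed by several cheap passes of a single-layer sampler, each of which commits one token while conditioning on the tokens already committed in this step. All names here (\texttt{toy\_full\_model}, \texttt{toy\_sampler\_layer}, the left-to-right order over masked positions, and the repeat-penalty interaction) are illustrative stand-ins, not the paper's actual architecture or training objective.

```python
import math
import random

VOCAB = 8  # toy vocabulary size


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]


def sample(probs, rng):
    # Draw one token index from a categorical distribution.
    r = rng.random()
    acc = 0.0
    for i, p in enumerate(probs):
        acc += p
        if r < acc:
            return i
    return len(probs) - 1


def approx_joint_unmask(full_model, sampler_layer, seq, masked, k, rng):
    """One full-model pass, then up to k cheap sampler-layer passes,
    each committing one token conditioned on those already committed."""
    hidden = full_model(seq)  # expensive: run once per denoising step
    committed = {}
    for _ in range(min(k, len(masked))):
        # Toy choice: fill masked positions left to right.
        pos = min(p for p in masked if p not in committed)
        logits = sampler_layer(hidden, pos, committed)  # cheap: one layer
        committed[pos] = sample(softmax(logits), rng)
    return committed


def toy_full_model(seq):
    # Stand-in for the large diffusion LM: per-position logits.
    rng = random.Random(0)
    return [[rng.gauss(0.0, 1.0) for _ in range(VOCAB)] for _ in seq]


def toy_sampler_layer(hidden, pos, committed):
    # Stand-in for the lightweight sampler: it sees the frozen hidden
    # states plus the tokens already committed in this step.
    logits = list(hidden[pos])
    for tok in committed.values():
        logits[tok] -= 2.0  # toy interaction: discourage repeats
    return logits


seq = [None] * 6          # fully masked toy sequence
masked = {1, 2, 4, 5}     # positions still masked
out = approx_joint_unmask(toy_full_model, toy_sampler_layer,
                          seq, masked, k=4, rng=random.Random(1))
```

Exact joint sampling would rerun the full model after every committed token; here the full model runs once and only the sampler layer reruns, which is the source of the speedup and of the approximation error the sampler is trained to minimize.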