Diffusion Probability Models (DPMs) have made impressive advancements in various machine learning domains. However, achieving high-quality synthetic samples typically involves performing a large number of sampling steps, which impedes the possibility of real-time sample synthesis. Traditional accelerated sampling algorithms via knowledge distillation rely on pre-trained model weights and discrete time step scenarios, necessitating additional training sessions to achieve their goals. To address these issues, we propose the Catch-Up Distillation (CUD), which encourages the current moment output of the velocity estimation model ``catch up'' with its previous moment output. Specifically, CUD adjusts the original Ordinary Differential Equation (ODE) training objective to align the current moment output with both the ground truth label and the previous moment output, utilizing Runge-Kutta-based multi-step alignment distillation for precise ODE estimation while preventing asynchronous updates. Furthermore, we investigate the design space for CUDs under continuous time-step scenarios and analyze how to determine the suitable strategies. To demonstrate CUD's effectiveness, we conduct thorough ablation and comparison experiments on CIFAR-10, MNIST, and ImageNet-64. On CIFAR-10, we obtain a FID of 2.80 by sampling in 15 steps under one-session training and the new state-of-the-art FID of 3.37 by sampling in one step with additional training. This latter result necessitated only 620k iterations with a batch size of 128, in contrast to Consistency Distillation, which demanded 2100k iterations with a larger batch size of 256. Our code is released at https://anonymous.4open.science/r/Catch-Up-Distillation-E31F.
翻译:扩散概率模型(DPMs)在多个机器学习领域取得了令人瞩目的进展。然而,要生成高质量的合成样本通常需要进行大量采样步骤,这阻碍了实时样本合成的可能性。传统的基于知识蒸馏的加速采样算法依赖于预训练模型权重和离散时间步场景,需要额外的训练阶段才能实现其目标。为解决这些问题,我们提出了追赶蒸馏(CUD),该方法鼓励速度估计模型在当前时刻的输出“追赶”其前一时刻的输出。具体而言,CUD通过调整原始常微分方程(ODE)的训练目标,使当前时刻输出与真实标签及前一时刻输出对齐,并利用基于龙格-库塔法的多步对齐蒸馏来实现精确的ODE估计,同时避免异步更新。此外,我们探索了连续时间步场景下CUD的设计空间,并分析了如何确定合适的策略。为验证CUD的有效性,我们在CIFAR-10、MNIST和ImageNet-64数据集上进行了详尽的消融实验与对比实验。在CIFAR-10数据集上,通过单阶段训练下的15步采样,我们获得了2.80的FID分数;而通过额外训练后的单步采样,我们取得了3.37的FID分数,创造了新的最优记录。后一结果仅需62万次迭代(批大小为128),而一致性蒸馏方法则需要210万次迭代(批大小为256)。我们的代码已发布于https://anonymous.4open.science/r/Catch-Up-Distillation-E31F。