Diffusion Probability Models (DPMs) have made impressive advancements in various machine learning domains. However, achieving high-quality synthetic samples typically involves performing a large number of sampling steps, which impedes the possibility of real-time sample synthesis. Traditional accelerated sampling algorithms via knowledge distillation rely on pre-trained model weights and discrete time step scenarios, necessitating additional training sessions to achieve their goals. To address these issues, we propose the Catch-Up Distillation (CUD), which encourages the current moment output of the velocity estimation model ``catch up'' with its previous moment output. Specifically, CUD adjusts the original Ordinary Differential Equation (ODE) training objective to align the current moment output with both the ground truth label and the previous moment output, utilizing Runge-Kutta-based multi-step alignment distillation for precise ODE estimation while preventing asynchronous updates. Furthermore, we investigate the design space for CUDs under continuous time-step scenarios and analyze how to determine the suitable strategies. To demonstrate CUD's effectiveness, we conduct thorough ablation and comparison experiments on CIFAR-10, MNIST, and ImageNet-64. On CIFAR-10, we obtain a FID of 2.80 by sampling in 15 steps under one-session training and the new state-of-the-art FID of 3.37 by sampling in one step with additional training. This latter result necessitated only 620k iterations with a batch size of 128, in contrast to Consistency Distillation, which demanded 2100k iterations with a larger batch size of 256. Our code is released at https://anonymous.4open.science/r/Catch-Up-Distillation-E31F.
翻译:扩散概率模型(DPMs)已在多个机器学习领域取得显著进展。然而,生成高质量合成样本通常需要执行大量采样步骤,这阻碍了实时样本合成的可能性。传统的基于知识蒸馏的加速采样算法依赖于预训练模型权重和离散时间步场景,需要额外的训练阶段才能实现目标。为解决这些问题,我们提出追赶蒸馏(Catch-Up Distillation, CUD),该方法促使速度估计模型当前时刻的输出“追赶”其上一时刻的输出。具体而言,CUD调整原始常微分方程(ODE)训练目标,使当前时刻输出同时对齐真实标签和上一时刻输出,利用基于龙格-库塔的多步对齐蒸馏实现精确的ODE估计,同时防止异步更新。此外,我们探索了连续时间步场景下CUD的设计空间,并分析了如何确定合适的策略。为验证CUD的有效性,我们在CIFAR-10、MNIST和ImageNet-64上进行了全面的消融与对比实验。在CIFAR-10上,通过单次训练、15步采样获得2.80的FID分数,而通过额外训练、单步采样获得新的最佳FID分数3.37——后者仅需620k次迭代(批大小为128),相比之下,一致性蒸馏(Consistency Distillation)则需要2100k次迭代(更大批大小256)。我们的代码已发布至 https://anonymous.4open.science/r/Catch-Up-Distillation-E31F。