Diffusion Probability Models (DPMs) have made impressive advancements in various machine learning domains. However, achieving high-quality synthetic samples typically involves performing a large number of sampling steps, which impedes the possibility of real-time sample synthesis. Traditional accelerated sampling algorithms via knowledge distillation rely on pre-trained model weights and discrete time step scenarios, necessitating additional training sessions to achieve their goals. To address these issues, we propose the Catch-Up Distillation (CUD), which encourages the current moment output of the velocity estimation model ``catch up'' with its previous moment output. Specifically, CUD adjusts the original Ordinary Differential Equation (ODE) training objective to align the current moment output with both the ground truth label and the previous moment output, utilizing Runge-Kutta-based multi-step alignment distillation for precise ODE estimation while preventing asynchronous updates. Furthermore, we investigate the design space for CUDs under continuous time-step scenarios and analyze how to determine the suitable strategies. To demonstrate CUD's effectiveness, we conduct thorough ablation and comparison experiments on CIFAR-10, MNIST, and ImageNet-64. On CIFAR-10, we obtain a FID of 2.80 by sampling in 15 steps under one-session training and the new state-of-the-art FID of 3.37 by sampling in one step with additional training. This latter result necessitated only 620k iterations with a batch size of 128, in contrast to Consistency Distillation, which demanded 2100k iterations with a larger batch size of 256. Our code is released at https://anonymous.4open.science/r/Catch-Up-Distillation-E31F.
翻译:扩散概率模型(DPMs)已在多个机器学习领域取得显著进展。然而,生成高质量合成样本通常需要执行大量采样步骤,这阻碍了实时样本合成的可能性。传统的基于知识蒸馏的加速采样算法依赖预训练模型权重和离散时间步场景,需要额外的训练阶段才能实现目标。为解决这些问题,我们提出追赶式蒸馏(CUD),该方法促使速度估计模型的当前时刻输出“追赶”其前一时刻输出。具体而言,CUD调整原始常微分方程(ODE)训练目标,使当前时刻输出同时与真实标签和前一时刻输出保持一致,并采用基于龙格-库塔的多步对齐蒸馏实现精确的ODE估计,同时防止异步更新。此外,我们研究了连续时间步场景下CUD的设计空间,分析了如何确定合适策略。为验证CUD有效性,我们在CIFAR-10、MNIST和ImageNet-64上进行了全面的消融与对比实验。在CIFAR-10上,通过单次训练以15步采样获得2.80的FID,并通过额外训练以单步采样获得当前最优的3.37 FID。后者仅需620k次迭代且批大小为128,而一致性蒸馏需要2100k次迭代及更大的批大小256。我们的代码已发布于https://anonymous.4open.science/r/Catch-Up-Distillation-E31F。