Text-to-image diffusion models have recently emerged at the forefront of image generation, powered by very large-scale unsupervised or weakly supervised text-to-image training datasets. Because of this unsupervised training, it is difficult to control their behavior in downstream tasks such as maximizing human-perceived image quality, image-text alignment, or ethical image generation. Recent works finetune diffusion models toward downstream reward functions using vanilla reinforcement learning, which is notorious for the high variance of its gradient estimators. In this paper, we propose AlignProp, a method that aligns diffusion models to downstream reward functions using end-to-end backpropagation of the reward gradient through the denoising process. While a naive implementation of such backpropagation would require prohibitive memory to store the partial derivatives of modern text-to-image models, AlignProp finetunes low-rank adapter weight modules and uses gradient checkpointing to keep its memory usage viable. We test AlignProp in finetuning diffusion models toward various objectives, such as image-text semantic alignment, aesthetics, compressibility, and controllability of the number of objects present, as well as their combinations. We show that AlignProp achieves higher rewards in fewer training steps than alternatives while being conceptually simpler, making it a straightforward choice for optimizing diffusion models toward differentiable reward functions of interest. Code and visualization results are available at https://align-prop.github.io/.
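The recipe the abstract describes can be illustrated on a toy problem. It backpropagates a differentiable reward through every denoising step, trains only low-rank adapter weights, and recomputes intermediate states instead of storing them. The sketch below is a minimal NumPy illustration under stated assumptions, not the AlignProp implementation. The "denoiser" is a hypothetical one-layer map x ← x − η·tanh(Wx) and the reward is a hypothetical negative squared distance to a fixed target. Gradients are derived by hand rather than by an autodiff framework, and names such as `reward_grad` and the checkpoint interval `every` are illustrative.

```python
import numpy as np

# Toy problem; every name and constant here is hypothetical.
rng = np.random.default_rng(0)
D, R, T, ETA, LR = 8, 2, 20, 0.1, 0.01
W0 = rng.normal(scale=0.3, size=(D, D))  # frozen "pretrained" weight
A = rng.normal(scale=0.1, size=(R, D))   # LoRA down-projection (trained)
B = np.zeros((D, R))                     # LoRA up-projection, zero-initialised (trained)
x_T = rng.normal(size=D)                 # fixed starting noise
target = np.ones(D)                      # the reward prefers samples near this point


def step(x, W):
    """One deterministic toy denoising step: x <- x - ETA * tanh(W @ x)."""
    return x - ETA * np.tanh(W @ x)


def sample(W):
    """Run all T denoising steps from the fixed noise x_T."""
    x = x_T
    for _ in range(T):
        x = step(x, W)
    return x


def reward(x0):
    """Differentiable toy reward: negative squared distance to `target`."""
    return -np.sum((x0 - target) ** 2)


def reward_grad(A, B, every=5):
    """Return (reward, dR/dA, dR/dB), backpropagating through all T steps.

    Checkpointing: the forward pass keeps only every `every`-th state, and the
    backward pass recomputes the other states of each segment on the fly.
    """
    W = W0 + B @ A
    ckpts, x = {}, x_T
    for t in range(T):
        if t % every == 0:
            ckpts[t] = x
        x = step(x, W)
    r = reward(x)
    g = -2.0 * (x - target)  # dR/dx at the final sample
    gW = np.zeros_like(W)
    for start in reversed(range(0, T, every)):
        end = min(start + every, T)
        xs = [ckpts[start]]  # recompute the inputs to steps start .. end-1
        for _ in range(start, end - 1):
            xs.append(step(xs[-1], W))
        for xt in reversed(xs):
            s = np.tanh(W @ xt)
            gz = -ETA * g * (1.0 - s**2)  # dR/d(W @ xt)
            gW += np.outer(gz, xt)        # accumulate dR/dW
            g = g + W.T @ gz              # dR/dx for this step's input state
    # W = W0 + B @ A with W0 frozen, so only the adapter factors get gradients.
    return r, B.T @ gW, gW @ A.T


r0 = reward(sample(W0 + B @ A))
for it in range(100):
    r, gA, gB = reward_grad(A, B)
    A, B = A + LR * gA, B + LR * gB  # gradient *ascent* on the reward
r_after = reward(sample(W0 + B @ A))
print(f"reward before: {r0:.3f}  after: {r_after:.3f}")
```

Calling `reward_grad` with `every=1` stores every intermediate state. Larger values keep only one state per segment and recompute the rest during the backward pass. Both settings yield the same gradients, which is the memory-for-compute trade that gradient checkpointing makes. In a real text-to-image model one would typically rely on an autodiff framework's own checkpointing utility around each denoiser call rather than a hand-written backward pass.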