Text-to-image diffusion models can generate stunning visuals, yet they often fail at tasks children find trivial--like placing a dog to the right of a teddy bear rather than to the left. When combinations get more unusual--a giraffe above an airplane--these failures become even more pronounced. Existing methods attempt to fix these spatial reasoning failures through model fine-tuning or test-time optimization with handcrafted losses that are suboptimal. Rather than imposing our assumptions about spatial encoding, we propose learning these objectives directly from the model's internal representations. We introduce Learn-to-Steer, a novel framework that learns data-driven objectives for test-time optimization rather than handcrafting them. Our key insight is to train a lightweight classifier that decodes spatial relationships from the diffusion model's cross-attention maps, then deploy this classifier as a learned loss function during inference. Training such classifiers poses a surprising challenge: they can take shortcuts by detecting linguistic traces in the cross-attention maps, rather than learning true spatial patterns. We solve this by augmenting our training data with samples generated using prompts with incorrect relation words, which encourages the classifier to avoid linguistic shortcuts and learn spatial patterns from the attention maps. Our method dramatically improves spatial accuracy: from 20% to 61% on FLUX.1-dev and from 7% to 54% on SD2.1 across standard benchmarks. It also generalizes to multiple relations with significantly improved accuracy.
翻译:文本到图像扩散模型能够生成令人惊叹的视觉内容,但在处理儿童都能轻松完成的任务时却常常失败——例如将小狗放置在泰迪熊的右侧而非左侧。当组合关系变得更加非常规时——例如长颈鹿位于飞机上方——这些失败会变得更加明显。现有方法试图通过模型微调或使用手工设计但非最优的损失函数进行测试时优化来修正这些空间推理错误。我们并未强加关于空间编码的先验假设,而是提出直接从模型内部表示中学习这些优化目标。我们引入了Learn-to-Steer这一新颖框架,通过数据驱动的方式学习测试时优化目标,而非依赖手工设计。我们的核心洞见是训练一个轻量级分类器,使其能够从扩散模型的交叉注意力图中解码空间关系,随后在推理阶段将该分类器作为习得的损失函数进行部署。训练此类分类器面临着一个意外的挑战:它们可能通过检测交叉注意力图中的语言痕迹来走捷径,而非真正学习空间模式。我们通过使用包含错误关系词的提示所生成的样本来增强训练数据,从而解决了这一问题,这促使分类器避免语言捷径,真正从注意力图中学习空间模式。我们的方法显著提升了空间准确性:在FLUX.1-dev数据集上从20%提升至61%,在SD2.1数据集上从7%提升至54%(基于标准基准测试)。该方法还能泛化至多种关系类型,并显著提升准确率。