Spurious correlations pose a major challenge for robust machine learning. Models trained with empirical risk minimization (ERM) may learn to rely on correlations between class labels and spurious attributes, leading to poor performance on data groups without these correlations. This is particularly challenging to address when spurious attribute labels are unavailable. To improve worst-group performance on spuriously correlated data without training attribute labels, we propose Correct-N-Contrast (CNC), a contrastive approach to directly learn representations robust to spurious correlations. As ERM models can be good spurious attribute predictors, CNC works by (1) using a trained ERM model's outputs to identify samples with the same class but dissimilar spurious features, and (2) training a robust model with contrastive learning to learn similar representations for same-class samples. To support CNC, we introduce new connections between worst-group error and a representation alignment loss that CNC aims to minimize. We empirically observe that worst-group error closely tracks alignment loss, and prove that the alignment loss over a class helps upper-bound the gap between that class's worst-group and average errors. On popular benchmarks, CNC reduces alignment loss drastically and achieves state-of-the-art worst-group accuracy, with a 3.6% average absolute lift. CNC is also competitive with oracle methods that require group labels.
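The two-step procedure in the abstract can be sketched as follows. This is a minimal illustration, not the authors' implementation: the helper names (`sample_pairs`, `alignment_loss`) and the use of Euclidean distance as the alignment measure are assumptions made for clarity. The pairing rule follows the abstract: positives share the anchor's class but receive a different ERM prediction (likely dissimilar spurious features), while negatives have a different class but the same ERM prediction (likely shared spurious features).

```python
import numpy as np

def sample_pairs(labels, erm_preds, anchor_idx):
    """For one anchor, pick contrastive positives and negatives using a
    trained ERM model's predictions as a proxy for spurious attributes.
    (Hypothetical helper; pairing rule taken from the abstract.)"""
    y, p = labels[anchor_idx], erm_preds[anchor_idx]
    # Positives: same class, different ERM prediction.
    positives = [i for i in range(len(labels))
                 if i != anchor_idx and labels[i] == y and erm_preds[i] != p]
    # Negatives: different class, same ERM prediction.
    negatives = [i for i in range(len(labels))
                 if labels[i] != y and erm_preds[i] == p]
    return positives, negatives

def alignment_loss(z, labels, erm_preds):
    """Mean representation distance between same-class samples that the ERM
    model predicts differently -- a simple proxy for the alignment loss the
    abstract relates to the worst-group vs. average error gap."""
    dists = []
    for a in range(len(labels)):
        positives, _ = sample_pairs(labels, erm_preds, a)
        for i in positives:
            dists.append(np.linalg.norm(z[a] - z[i]))
    return float(np.mean(dists)) if dists else 0.0
```

A contrastive objective built from these pairs pulls each anchor's representation toward its positives and away from its negatives; driving the alignment loss down is what, per the abstract's analysis, tightens the bound on the worst-group vs. average error gap within each class.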