Conventional approaches to robustness try to learn a model based on causal features. However, identifying maximally robust or causal features may be difficult in some scenarios, and in others, non-causal "shortcut" features may actually be more predictive. We propose a lightweight, sample-efficient approach that learns a diverse set of features and adapts to a target distribution by interpolating these features using a small target dataset. Our approach, Project and Probe (Pro$^2$), first learns a linear projection that maps a pre-trained embedding onto orthogonal directions that are predictive of labels in the source dataset. The goal of this step is to learn a variety of predictive features, so that at least some of them remain useful after distribution shift. Pro$^2$ then learns a linear classifier on top of these projected features using a small target dataset. We show theoretically that Pro$^2$ learns a projection matrix that is optimal for classification in an information-theoretic sense, which yields better generalization through a favorable bias-variance tradeoff. Our experiments on four datasets, each with multiple distribution shift settings, show that when target data is limited, Pro$^2$ improves performance by 5-15% over prior methods such as standard linear probing.
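The two steps can be illustrated with a minimal NumPy sketch. This is not the authors' implementation. It assumes binary labels, plain logistic regression fit by gradient descent, and a sequential way of enforcing orthogonality: each new direction is fit on source features projected onto the orthogonal complement of the directions found so far. The helper names (`fit_logreg`, `learn_projection`, `project_and_probe`, `predict`, `make_data`) and the synthetic shortcut data are hypothetical and chosen only for illustration.

```python
import numpy as np


def fit_logreg(X, y, l2=1e-3, lr=0.5, steps=500):
    """Binary logistic regression (labels in {0, 1}) fit by full-batch gradient descent."""
    n, d = X.shape
    w, b = np.zeros(d), 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))  # predicted P(y = 1 | x)
        g = p - y                                 # gradient of the log-loss w.r.t. the logits
        w -= lr * (X.T @ g / n + l2 * w)
        b -= lr * g.mean()
    return w, b


def learn_projection(X_src, y_src, k):
    """Project step (sketch): learn k orthonormal directions, each fit to predict the
    source labels inside the orthogonal complement of the directions found so far."""
    d = X_src.shape[1]
    dirs = []
    for _ in range(k):
        P = np.eye(d)                # projector onto the complement of previous directions
        for v in dirs:
            P -= np.outer(v, v)
        w, _ = fit_logreg(X_src @ P, y_src)
        v = P @ w                    # re-project onto the complement (guards against float drift)
        dirs.append(v / np.linalg.norm(v))
    return np.stack(dirs, axis=1)    # shape (d, k) with orthonormal columns


def project_and_probe(X_src, y_src, X_tgt, y_tgt, k):
    """Probe step (sketch): fit a linear classifier on the k projected target features."""
    Pi = learn_projection(X_src, y_src, k)
    w, b = fit_logreg(X_tgt @ Pi, y_tgt)
    return Pi, w, b


def predict(X, Pi, w, b):
    return (X @ Pi @ w + b > 0).astype(int)


# Hypothetical synthetic "embeddings": feature 0 is a low-noise shortcut whose sign flips
# in the target domain, feature 1 is a noisier label-correlated feature, 2..9 are noise.
def make_data(rng, n, shortcut_sign):
    y = rng.integers(0, 2, size=n)
    s = 2 * y - 1
    X = rng.normal(size=(n, 10))
    X[:, 0] = shortcut_sign * s + 0.3 * rng.normal(size=n)
    X[:, 1] = s + rng.normal(size=n)
    return X, y


rng = np.random.default_rng(0)
X_src, y_src = make_data(rng, 2000, shortcut_sign=+1)
X_tgt, y_tgt = make_data(rng, 40, shortcut_sign=-1)      # small labeled target set
X_test, y_test = make_data(rng, 2000, shortcut_sign=-1)

Pi, w, b = project_and_probe(X_src, y_src, X_tgt, y_tgt, k=3)
target_acc = (predict(X_test, Pi, w, b) == y_test).mean()
```

In this sketch the probe fits only `k` weights (plus a bias) from the small target set instead of one weight per embedding dimension, which is one way to read the bias-variance argument in the abstract. The projection is learned once on abundant source data, and the target labels only decide how to combine the `k` source-predictive directions.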