Spurious correlations, or correlations that change across the domains where a model may be deployed, present significant challenges to real-world applications of machine learning models. However, such correlations are not always "spurious"; often, they provide valuable prior information for a prediction. Here, we present a test-time adaptation method that exploits the spurious correlation phenomenon, in contrast to recent approaches that attempt to eliminate spurious correlations through invariance. We consider situations where the prior distribution $p(y, z)$, which models the dependence between the class label $y$ and the "nuisance" factors $z$, may change across domains, but the generative model for features $p(\mathbf{x}|y, z)$ is constant. We note that this corresponds to an expanded version of the label shift assumption, in which the labels now also include the nuisance factors $z$. Based on this observation, we train a classifier to predict $p(y, z|\mathbf{x})$ on the source distribution, and propose a test-time label shift correction that adapts to changes in the marginal distribution $p(y, z)$ using unlabeled samples from the target domain. We evaluate our method, which we call "Test-Time Label-Shift Adaptation" (TTLSA), on two image datasets (the CheXpert chest X-ray dataset and the Colored MNIST dataset) and show a significant improvement over baseline methods. Code to reproduce the experiments is available at https://github.com/nalzok/test-time-label-shift .
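The abstract leaves the correction step implicit. Because $p(\mathbf{x} \mid y, z)$ is assumed fixed, Bayes' rule gives $p_t(y, z \mid \mathbf{x}) \propto p_s(y, z \mid \mathbf{x})\, p_t(y, z) / p_s(y, z)$, and the label of interest follows by summing out the nuisance factor, $p_t(y \mid \mathbf{x}) = \sum_z p_t(y, z \mid \mathbf{x})$. The unknown target prior $p_t(y, z)$ still has to be estimated from unlabeled target data. The sketch below assumes the classic EM procedure of Saerens et al. (2002) for that step, applied to the joint label $(y, z)$. It is a minimal NumPy illustration, not the authors' implementation (the linked repository has that); the function names `em_prior_adaptation` and `marginalize_nuisance`, the row-major flattening of $(y, z)$, and the synthetic four-input toy data are hypothetical choices made for this example.

```python
import numpy as np


def em_prior_adaptation(probs_src, prior_src, n_iter=1000, tol=1e-8):
    """EM re-estimation of a shifted label prior (Saerens et al., 2002).

    probs_src: (n, K) posteriors p_s(y, z | x) from the source-trained
        classifier on n unlabeled target inputs; the K = n_y * n_z joint
        labels are flattened row-major (hypothetical convention:
        index = y * n_z + z).
    prior_src: (K,) source prior p_s(y, z); every entry must be > 0.
    Returns the estimated target prior p_t(y, z) and the adapted
    posteriors p_t(y, z | x).
    """
    probs_src = np.asarray(probs_src, dtype=float)
    prior_src = np.asarray(prior_src, dtype=float)

    def adapt(prior_tgt):
        # Bayes correction: p_t(y,z|x) ∝ p_s(y,z|x) * p_t(y,z) / p_s(y,z)
        w = probs_src * (prior_tgt / prior_src)
        return w / w.sum(axis=1, keepdims=True)

    prior_tgt = prior_src.copy()
    for _ in range(n_iter):
        # E-step: adapted posteriors under the current prior estimate;
        # M-step: the new prior is their average over the target samples.
        new_prior = adapt(prior_tgt).mean(axis=0)
        converged = np.max(np.abs(new_prior - prior_tgt)) < tol
        prior_tgt = new_prior
        if converged:
            break
    return prior_tgt, adapt(prior_tgt)


def marginalize_nuisance(probs_joint, n_y, n_z):
    """Sum out the nuisance factor: p_t(y | x) = sum_z p_t(y, z | x)."""
    return probs_joint.reshape(-1, n_y, n_z).sum(axis=2)


# Synthetic toy check: 2 classes x 2 nuisance values, discrete x with 4 values.
# lik[k, j] = p(x = k | joint label j), identical in source and target.
lik = np.full((4, 4), 0.1) + 0.6 * np.eye(4)
prior_src = np.full(4, 0.25)
# Target inputs in the exact proportions implied by a hidden target prior
# [0.4, 0.1, 0.1, 0.4], since p_t(x = k) = 0.1 + 0.6 * prior[k].
x_tgt = np.repeat(np.arange(4), [34, 16, 16, 34])
# Source-classifier posteriors for those inputs (Bayes' rule under prior_src).
post = lik[x_tgt] * prior_src
probs_src = post / post.sum(axis=1, keepdims=True)

prior_tgt, probs_tgt = em_prior_adaptation(probs_src, prior_src)
p_y = marginalize_nuisance(probs_tgt, n_y=2, n_z=2)
print(np.round(prior_tgt, 3))  # should be close to [0.4, 0.1, 0.1, 0.4]
```

Working on the joint label rather than on $y$ alone is what lets the correction track a change in how $y$ and $z$ co-occur; the nuisance factor is only summed out after the prior has been adapted.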