Changes in the data distribution at test time can have deleterious effects on the performance of predictive models $p(y|x)$. We consider situations where additional meta-data labels (such as group labels), denoted by $z$, can account for such changes in the distribution. In particular, we assume that the prior distribution $p(y, z)$, which models the dependence between the class label $y$ and the "nuisance" factors $z$, may change across domains, either through a change in the correlation between these terms or through a change in one of their marginals. However, we assume that the generative model for features, $p(x|y,z)$, is invariant across domains. This corresponds to an expanded version of the widely used "label shift" assumption, in which the labels now also include the nuisance factors $z$. Based on this observation, we propose a test-time label shift correction that adapts to changes in the joint distribution $p(y, z)$ by applying EM to unlabeled samples from the target domain distribution, $p_t(x)$. Importantly, we avoid fitting a generative model $p(x|y, z)$ and merely need to reweight the outputs of a discriminative model $p_s(y, z|x)$ trained on the source distribution. We evaluate our method, which we call "Test-Time Label-Shift Adaptation" (TTLSA), on several standard image and text datasets, as well as the CheXpert chest X-ray dataset, and show that it improves performance over both methods that target invariance to changes in the distribution and baseline empirical risk minimization (ERM) methods. Code for reproducing experiments is available at https://github.com/nalzok/test-time-label-shift.
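Under the expanded label shift assumption ($p(x|y,z)$ fixed, only $p(y,z)$ changing), Bayes' rule gives $p_t(y,z|x) \propto p_s(y,z|x)\,\frac{p_t(y,z)}{p_s(y,z)}$, so the unknown target prior can be estimated by EM: the E-step reweights the source posteriors with the current prior ratio, and the M-step sets $p_t(y,z) \leftarrow \frac{1}{n}\sum_i p_t(y,z|x_i)$. The following is a minimal sketch of this classic EM prior-adjustment loop (in the style of Saerens et al., 2002) applied to the joint label $(y, z)$. It is not the authors' implementation (see the linked repository for that). The function name `em_label_shift`, the flattening $k = 2y + z$, and the synthetic Gaussian data, which stand in for a trained discriminative model $p_s(y,z|x)$, are all illustrative assumptions.

```python
import numpy as np


def em_label_shift(probs_source, prior_source, n_iter=500, tol=1e-8):
    """Hypothetical helper: EM estimate of the target prior p_t(y, z).

    probs_source: (n, K) array of source-model posteriors p_s(y, z | x_i) on
        unlabeled target inputs; the K columns index flattened joint labels (y, z).
    prior_source: (K,) array p_s(y, z) from the source training set (all entries > 0).
    Returns (prior_target, probs_target), the adapted prior and posteriors.
    """
    probs_source = np.asarray(probs_source, dtype=float)
    prior_source = np.asarray(prior_source, dtype=float)
    prior_target = prior_source.copy()
    for _ in range(n_iter):
        # E-step: reweight source posteriors by the prior ratio p_t / p_s, renormalize.
        unnorm = probs_source * (prior_target / prior_source)
        probs_target = unnorm / unnorm.sum(axis=1, keepdims=True)
        # M-step: the new target prior is the average adapted posterior.
        new_prior = probs_target.mean(axis=0)
        converged = np.abs(new_prior - prior_target).max() < tol
        prior_target = new_prior
        if converged:
            break
    # Final adapted posteriors under the estimated target prior.
    unnorm = probs_source * (prior_target / prior_source)
    probs_target = unnorm / unnorm.sum(axis=1, keepdims=True)
    return prior_target, probs_target


# --- Synthetic demo (assumption: known Gaussians replace a trained model) ---
rng = np.random.default_rng(0)
K = 4  # joint labels (y, z), y and z binary, flattened as k = 2 * y + z
means = np.array([[0.0, 0.0], [0.0, 3.0], [3.0, 0.0], [3.0, 3.0]])
prior_s = np.full(K, 0.25)                # balanced source prior p_s(y, z)
prior_t = np.array([0.4, 0.1, 0.1, 0.4])  # shifted target prior: y and z now correlated

# Draw unlabeled target samples; p(x | y, z) is the same Gaussian in both domains.
n = 20000
labels = rng.choice(K, size=n, p=prior_t)
x = means[labels] + rng.standard_normal((n, 2))

# Source posteriors p_s(y, z | x) via Bayes' rule with unit-variance Gaussians.
log_lik = -0.5 * ((x[:, None, :] - means[None, :, :]) ** 2).sum(axis=-1)
log_lik -= log_lik.max(axis=1, keepdims=True)  # for numerical stability
post_s = np.exp(log_lik) * prior_s
post_s /= post_s.sum(axis=1, keepdims=True)

est_prior, post_t = em_label_shift(post_s, prior_s)

# Class prediction: marginalize the nuisance factor z out of p_t(y, z | x).
p_y = post_t.reshape(n, 2, 2).sum(axis=2)
y_hat = p_y.argmax(axis=1)
print("estimated p_t(y, z):", np.round(est_prior, 3))
```

Because only the $K$ prior weights are updated, adaptation needs nothing but the source model's outputs on a batch of unlabeled target inputs, which matches the abstract's point that no generative model $p(x|y,z)$ has to be fit.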