Changes in the data distribution at test time can have deleterious effects on the performance of predictive models $p(y|x)$. We consider situations where there are additional meta-data labels (such as group labels), denoted by $z$, that can account for such changes in the distribution. In particular, we assume that the prior distribution $p(y, z)$, which models the dependence between the class label $y$ and the "nuisance" factors $z$, may change across domains, either due to a change in the correlation between these terms, or a change in one of their marginals. However, we assume that the generative model for features $p(x|y, z)$ is invariant across domains. We note that this corresponds to an expanded version of the widely used "label shift" assumption, where the labels now also include the nuisance factors $z$. Based on this observation, we propose a test-time label shift correction that adapts to changes in the joint distribution $p(y, z)$ using EM applied to unlabeled samples from the target domain distribution, $p_t(x)$. Importantly, we are able to avoid fitting a generative model $p(x|y,z)$, and merely need to reweight the outputs of a discriminative model $p_s(y,z|x)$ trained on the source distribution. We evaluate our method, which we call "Test-Time Label-Shift Adaptation" (TTLSA), on several standard image and text datasets, as well as the CheXpert chest X-ray dataset, and show that it improves performance over methods that target invariance to changes in the distribution, as well as baseline empirical risk minimization methods. Code for reproducing experiments is available at https://github.com/nalzok/test-time-label-shift .
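To illustrate the reweighting idea, here is a minimal sketch of EM-based prior adaptation, assuming the source classifier outputs $p_s(y,z|x)$ as a matrix with one column per flattened $(y,z)$ pair. The function names `em_label_shift` and `marginalize_over_z` and the index layout `m = y * n_groups + z` are hypothetical choices for this example and are not taken from the authors' repository. The E-step rescales each source posterior by the ratio of the current target prior to the source prior, and the M-step averages the resulting posteriors to update the target prior.

```python
import numpy as np


def em_label_shift(source_probs, source_prior, n_iter=100, tol=1e-8):
    """Estimate the target prior p_t(y, z) from unlabeled target samples.

    source_probs: (N, K) array of p_s(m | x_n), where m indexes the
        flattened (y, z) pairs (an assumed layout for this sketch).
    source_prior: (K,) array of p_s(m); every entry must be positive.
    Returns the estimated target prior (K,) and adapted posteriors (N, K).
    """
    source_probs = np.asarray(source_probs, dtype=float)
    source_prior = np.asarray(source_prior, dtype=float)
    target_prior = source_prior.copy()

    for _ in range(n_iter):
        # E-step: p_t(m | x_n) is proportional to p_s(m | x_n) * p_t(m) / p_s(m).
        post = source_probs * (target_prior / source_prior)
        post /= post.sum(axis=1, keepdims=True)
        # M-step: the new prior is the average adapted posterior.
        new_prior = post.mean(axis=0)
        converged = np.abs(new_prior - target_prior).max() < tol
        target_prior = new_prior
        if converged:
            break

    # Recompute posteriors so they match the final prior estimate.
    post = source_probs * (target_prior / source_prior)
    post /= post.sum(axis=1, keepdims=True)
    return target_prior, post


def marginalize_over_z(post, n_classes, n_groups):
    """Sum p_t(y, z | x) over z to obtain p_t(y | x)."""
    return post.reshape(-1, n_classes, n_groups).sum(axis=2)
```

Note that no generative model $p(x|y,z)$ is fitted here: only the classifier outputs on target inputs and the source prior are needed. Predictions for the class label are then obtained by calling `marginalize_over_z` on the adapted posteriors.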