Test-time adaptation (TTA) is an effective approach to mitigating the performance degradation of trained models when they encounter input distribution shifts at test time. However, existing TTA methods often suffer significant performance drops when additionally facing class distribution shifts. We first analyze TTA methods under label distribution shifts and identify class-wise confusion patterns that appear consistently across different covariate shifts. Based on this observation, we introduce label Distribution shift-Aware prediction Refinement for Test-time adaptation (DART), a novel TTA method that refines predictions by focusing on these class-wise confusion patterns. DART trains a prediction refinement module during an intermediate phase between training and testing, exposing it to batches with diverse class distributions sampled from the training dataset. At test time, this module detects and corrects class distribution shifts, significantly improving pseudo-label accuracy on test data. Our method yields 5-18% accuracy gains under label distribution shifts on CIFAR-10C, with no performance degradation when no label distribution shift is present. Extensive experiments on the CIFAR, PACS, OfficeHome, and ImageNet benchmarks demonstrate DART's ability to correct inaccurate predictions caused by test-time distribution shifts. This improvement in turn boosts the performance of existing TTA methods, making DART a valuable plug-in tool.
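The core idea of prediction refinement under a shifted class prior can be illustrated with a much simpler, classical stand-in: reweighting softmax outputs by the ratio of an estimated test-time class prior to the training prior. The sketch below is not DART's learned module; the function names, the Dirichlet sampling of diverse class distributions, and the prior-correction rule are assumptions made for illustration only.

```python
import numpy as np

def refine_with_prior(logits, train_prior, test_prior):
    """Reweight softmax probabilities by the ratio of an estimated
    test-time class prior to the training prior, then renormalize.
    Classical prior correction -- a simplified stand-in for a learned
    prediction refinement module (hypothetical illustration)."""
    z = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(z)
    probs /= probs.sum(axis=1, keepdims=True)
    refined = probs * (test_prior / train_prior)    # prior-ratio reweighting
    refined /= refined.sum(axis=1, keepdims=True)   # renormalize to a distribution
    return refined

def sample_priors(num_classes, num_batches, alpha=1.0, seed=0):
    """Draw diverse class distributions from a Dirichlet -- one way to
    simulate label-distribution shift when training a refinement module
    on batches with varied class balance (hypothetical illustration)."""
    rng = np.random.default_rng(seed)
    return rng.dirichlet(alpha * np.ones(num_classes), size=num_batches)
```

For example, if a sample's logits narrowly favor class 1 but the test stream is known to be heavily skewed toward class 0, the refined probabilities can flip the prediction to class 0, which is the kind of correction a distribution-shift-aware module is meant to make.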