The main challenge in domain generalization (DG) is handling the distribution shift between training and test data. Recent studies suggest that test-time training (TTT), which adapts the learned model using test data, is a promising solution. In general, the performance of a TTT strategy hinges on two factors: selecting an appropriate auxiliary TTT task for updating, and identifying reliable parameters to update during the test phase. Both prior work and our experiments indicate that TTT may fail to improve the learned model, and can even harm it, if these two factors are not properly considered. This work addresses both factors by proposing an Improved Test-Time Adaptation (ITTA) method. First, instead of heuristically defining an auxiliary objective, we propose a learnable consistency loss for the TTT task, whose learnable parameters can be adjusted to better align the TTT task with the main prediction task. Second, we introduce additional adaptive parameters for the trained model, and we suggest updating only these adaptive parameters during the test phase. Through extensive experiments, we show that both strategies benefit the learned model (see Figure 1), and that ITTA achieves superior performance to current state-of-the-art methods on several DG benchmarks. Code is available at https://github.com/liangchen527/ITTA.
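The two ideas in the abstract (a consistency loss with its own parameters, and test-time updates restricted to a small set of adaptive parameters) can be illustrated with a toy sketch in plain Python. This is not the authors' implementation: the per-channel gates, the weighted squared-difference loss, and every name and number below (`frozen_w`, `loss_w`, `gates`, `adapt_gates`) are assumptions chosen only to keep the example small and self-contained.

```python
# Toy illustration only: all names, shapes, and values are hypothetical.

def features(frozen_w, x):
    """Frozen feature extractor: one linear unit per output channel."""
    return [sum(w_i * x_i for w_i, x_i in zip(row, x)) for row in frozen_w]


def consistency_loss(loss_w, gates, z_a, z_b):
    """Weighted squared difference between the gated features of two views.

    loss_w stands in for the learnable loss parameters; it is assumed to have
    been learned during training and is held fixed at test time.
    """
    return sum(w * (g * (a - b)) ** 2
               for w, g, a, b in zip(loss_w, gates, z_a, z_b))


def adapt_gates(frozen_w, loss_w, gates, x, x_aug, lr=0.1, steps=1):
    """Run gradient steps on the adaptive gates only.

    frozen_w and loss_w are read but never modified.
    """
    z_a, z_b = features(frozen_w, x), features(frozen_w, x_aug)
    gates = list(gates)  # copy so the caller's list is left untouched
    for _ in range(steps):
        # d/dg_k of w_k * (g_k * d_k)^2 is 2 * w_k * g_k * d_k^2, with d_k = a_k - b_k
        grads = [2 * w * g * (a - b) ** 2
                 for w, g, a, b in zip(loss_w, gates, z_a, z_b)]
        gates = [g - lr * gr for g, gr in zip(gates, grads)]
    return gates


frozen_w = [[1.0, 0.0], [0.0, 1.0]]  # frozen backbone weights
loss_w = [1.0, 1.0]                  # stand-in for learned loss parameters
gates = [1.0, 1.0]                   # adaptive parameters, updated at test time
x = [1.0, 2.0]                       # a test sample
x_aug = [1.0, 3.0]                   # an augmented view; only the second input changes

z, z_aug = features(frozen_w, x), features(frozen_w, x_aug)
loss_before = consistency_loss(loss_w, gates, z, z_aug)
new_gates = adapt_gates(frozen_w, loss_w, gates, x, x_aug)
loss_after = consistency_loss(loss_w, new_gates, z, z_aug)
```

In this toy setup, the gate on the channel that changes under augmentation shrinks while the stable channel's gate stays put, so the consistency loss drops without touching the frozen weights. A loss like this is trivially minimized by driving all gates to zero, which is one reason the sketch takes only a single small step; how the actual method aligns the auxiliary loss with the main task is described in the paper and the linked repository.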