Test-time training (TTT) methods explicitly update the weights of a model to adapt to the specific test instance, and they have found success in a variety of settings, most recently in language modeling and reasoning. To demystify this success, we investigate a gradient-based TTT algorithm for in-context learning, where we train a transformer model on the in-context demonstrations provided in the test prompt. Specifically, we provide a comprehensive theoretical characterization of linear transformers when the update rule is a single gradient step. Our theory (i) delineates the role of alignment between the pretraining distribution and the target task, (ii) demystifies how TTT can alleviate distribution shift, and (iii) quantifies the sample complexity of TTT, including how it can significantly reduce the sample size eventually required for in-context learning. As our empirical contribution, we study the benefits of TTT for TabPFN, a tabular foundation model. In line with our theory, we demonstrate that TTT significantly reduces the sample size required for tabular classification (3 to 5 times fewer samples), unlocking substantial inference efficiency at negligible training cost.
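The core update studied here can be sketched in a few lines. The following is a minimal illustration, not the paper's implementation: it assumes a linear predictor (standing in for the linear transformer), takes a single gradient step on the squared loss over the in-context demonstrations, and then predicts the query. The function name `ttt_predict` and the step size `lr` are illustrative choices.

```python
import numpy as np

def ttt_predict(W, X_ctx, y_ctx, x_query, lr=0.1):
    """One gradient step on the in-context demonstrations, then predict.

    W       : (d,) pretrained linear weights
    X_ctx   : (n, d) in-context demonstration inputs
    y_ctx   : (n,) demonstration labels
    x_query : (d,) test input
    """
    n = len(y_ctx)
    residual = X_ctx @ W - y_ctx       # prediction errors on the context
    grad = X_ctx.T @ residual / n      # gradient of the mean squared loss (up to a factor of 2)
    W_adapted = W - lr * grad          # single test-time update
    return x_query @ W_adapted

# Toy usage: the pretrained weights (here, zeros) are misaligned with the
# test task w_star; one TTT step adapts toward it using the context.
rng = np.random.default_rng(0)
d, n = 5, 40
w_star = rng.normal(size=d)
X = rng.normal(size=(n, d))
y = X @ w_star
W0 = np.zeros(d)
x_q = rng.normal(size=d)
pred = ttt_predict(W0, X, y, x_q, lr=0.1)
```

Even this one step strictly reduces the loss on the demonstrations for a sufficiently small step size, which is the mechanism the theory quantifies (alignment, distribution shift, and sample complexity) for linear transformers.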