Attention-based neural networks such as transformers have demonstrated a remarkable ability to exhibit in-context learning (ICL): Given a short prompt sequence of tokens from an unseen task, they can formulate relevant per-token and next-token predictions without any parameter updates. By embedding a sequence of labeled training data and unlabeled test data as a prompt, this allows for transformers to behave like supervised learning algorithms. Indeed, recent work has shown that when training transformer architectures over random instances of linear regression problems, these models' predictions mimic those of ordinary least squares. Towards understanding the mechanisms underlying this phenomenon, we investigate the dynamics of ICL in transformers with a single linear self-attention layer trained by gradient flow on linear regression tasks. We show that despite non-convexity, gradient flow with a suitable random initialization finds a global minimum of the objective function. At this global minimum, when given a test prompt of labeled examples from a new prediction task, the transformer achieves prediction error competitive with the best linear predictor over the test prompt distribution. We additionally characterize the robustness of the trained transformer to a variety of distribution shifts and show that although a number of shifts are tolerated, shifts in the covariate distribution of the prompts are not. Motivated by this, we consider a generalized ICL setting where the covariate distributions can vary across prompts. We show that although gradient flow succeeds at finding a global minimum in this setting, the trained transformer is still brittle under mild covariate shifts.
翻译:基于注意力的神经网络(如Transformer)展现出卓越的上下文学习(ICL)能力:给定来自未知任务的短提示词序列,它们能在不更新参数的情况下,生成相关的逐词和下一词预测。通过将带标签的训练数据序列与无标签测试数据嵌入提示词中,Transformer能够像监督学习算法一样运作。实际上,近期的研究表明,当在随机线性回归问题实例上训练Transformer架构时,这些模型的预测结果与普通最小二乘法高度相似。为理解这一现象的底层机制,我们研究了单层线性自注意力Transformer在梯度流训练下的ICL动力学。我们证明,尽管目标函数非凸,但采用合适的随机初始化后,梯度流仍能找到全局最优解。在此全局最优解下,当给定新预测任务中带标签示例的测试提示词时,Transformer的预测误差可与测试提示词分布上的最佳线性预测器相媲美。此外,我们还刻画了训练后的Transformer对多种分布偏移的鲁棒性,发现尽管模型能容忍若干偏移类型,但无法应对提示词协变量分布的偏移。基于此,我们考虑了一种广义ICL设置,其中不同提示词的协变量分布可发生变化。结果表明,尽管梯度流在该设置下仍能收敛到全局最优解,但训练后的Transformer在轻微协变量偏移下依然脆弱。