Attention-based neural networks such as transformers have demonstrated a remarkable capacity for in-context learning (ICL): given a short prompt sequence of tokens from an unseen task, they can formulate relevant per-token and next-token predictions without any parameter updates. When a sequence of labeled training data and unlabeled test data is embedded as a prompt, transformers can thus behave like supervised learning algorithms. Indeed, recent work has shown that when transformer architectures are trained on random instances of linear regression problems, their predictions mimic those of ordinary least squares. To understand the mechanisms underlying this phenomenon, we investigate the dynamics of ICL in transformers with a single linear self-attention layer trained by gradient flow on linear regression tasks. We show that despite non-convexity, gradient flow with a suitable random initialization finds a global minimum of the objective function. At this global minimum, when given a test prompt of labeled examples from a new prediction task, the transformer achieves prediction error competitive with the best linear predictor over the test prompt distribution. We additionally characterize the robustness of the trained transformer to a variety of distribution shifts and show that although a number of shifts are tolerated, shifts in the covariate distribution of the prompts are not. Motivated by this, we consider a generalized ICL setting where the covariate distributions can vary across prompts. We show that although gradient flow succeeds at finding a global minimum in this setting, the trained transformer is still brittle under mild covariate shifts. We complement this finding with experiments on large, nonlinear transformer architectures, which we show are more robust under covariate shifts.
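To make the setting concrete, the following is a minimal sketch of in-context linear regression with an attention-style estimator. It is an illustration under stated assumptions, not the construction analyzed in the paper: the abstract does not give the learned weights, so the sketch uses a hand-picked predictor, the inner product of the query with the prompt average of y_i x_i, which is a reasonable stand-in only when covariates have identity covariance. Labels are assumed noiseless, and all function names and parameters (`sample_prompt`, `linear_attention_predict`, `mean_squared_error`, `scale`) are hypothetical.

```python
import random


def sample_prompt(w, n, scale, rng):
    """Draw n labeled examples with x_i ~ N(0, scale^2 I) and y_i = <w, x_i>."""
    d = len(w)
    xs = [[rng.gauss(0.0, scale) for _ in range(d)] for _ in range(n)]
    ys = [sum(wj * xj for wj, xj in zip(w, x)) for x in xs]
    return xs, ys


def linear_attention_predict(xs, ys, x_query):
    """Predict the query label as <x_query, (1/n) * sum_i y_i x_i>.

    Assumption: a fixed, hand-picked estimator suited to identity-covariance
    covariates, not weights obtained by gradient flow.
    """
    n, d = len(xs), len(x_query)
    w_hat = [sum(y * x[j] for x, y in zip(xs, ys)) / n for j in range(d)]
    return sum(wj * qj for wj, qj in zip(w_hat, x_query))


def mean_squared_error(scale, d=5, n=1000, trials=30, seed=0):
    """Average squared prediction error over random tasks.

    scale = 1.0 matches the covariate distribution the estimator assumes;
    any other value simulates a covariate shift at test time.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        w = [rng.gauss(0.0, 1.0) for _ in range(d)]   # a fresh task
        xs, ys = sample_prompt(w, n, scale, rng)       # in-context examples
        x_query = [rng.gauss(0.0, scale) for _ in range(d)]
        y_true = sum(wj * qj for wj, qj in zip(w, x_query))
        y_pred = linear_attention_predict(xs, ys, x_query)
        total += (y_pred - y_true) ** 2
    return total / trials


in_dist = mean_squared_error(scale=1.0)
shifted = mean_squared_error(scale=2.0)
print(f"error without covariate shift: {in_dist:.4f}")
print(f"error with covariate shift:    {shifted:.4f}")
```

With identity-covariance covariates, the prompt average of y_i x_i concentrates around the task vector w, so the prediction is close to the true label. When the covariates are rescaled, the same average concentrates around a rescaled copy of w and the error grows, which mirrors in a toy form the brittleness under covariate shift described above.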