At present, the mechanisms of in-context learning in Transformers are not well understood and remain mostly an intuition. In this paper, we suggest that training Transformers on auto-regressive objectives is closely related to gradient-based meta-learning formulations. We start by providing a simple weight construction that shows the equivalence of data transformations induced by 1) a single linear self-attention layer and by 2) gradient-descent (GD) on a regression loss. Motivated by that construction, we show empirically that when training self-attention-only Transformers on simple regression tasks either the models learned by GD and Transformers show great similarity or, remarkably, the weights found by optimization match the construction. Thus we show how trained Transformers become mesa-optimizers i.e. learn models by gradient descent in their forward pass. This allows us, at least in the domain of regression problems, to mechanistically understand the inner workings of in-context learning in optimized Transformers. Building on this insight, we furthermore identify how Transformers surpass the performance of plain gradient descent by learning an iterative curvature correction and learn linear models on deep data representations to solve non-linear regression tasks. Finally, we discuss intriguing parallels to a mechanism identified to be crucial for in-context learning termed induction-head (Olsson et al., 2022) and show how it could be understood as a specific case of in-context learning by gradient descent learning within Transformers. Code to reproduce the experiments can be found at https://github.com/google-research/self-organising-systems/tree/master/transformers_learn_icl_by_gd .
翻译:目前,Transformer 中上下文学习的内在机制尚未被充分理解,大多停留在直觉层面。本文提出,基于自回归目标训练的 Transformer 与基于梯度的元学习公式密切相关。我们首先通过一个简单的权重构造,证明了以下两种数据变换的等价性:1) 单个线性自注意力层诱导的变换;2) 回归损失上的梯度下降(GD)诱导的变换。基于这一构造,我们通过实验证明,在简单回归任务上仅训练自注意力机制的 Transformer 时,无论是 GD 学习的模型还是 Transformer 学习的模型都表现出高度相似性,甚至优化得到的权重与构造完全匹配。由此,我们揭示了训练后的 Transformer 如何成为元优化器——即在前向传播过程中通过梯度下降学习模型。这使我们至少在回归问题领域,能够从机理层面理解优化后 Transformer 中上下文学习的内部运作机制。基于这一洞见,我们进一步发现 Transformer 如何通过学习迭代曲率校正来超越简单梯度下降的性能,并通过学习深层数据表示的线性模型来解决非线性回归任务。最后,我们讨论了与一个被识别为上下文学习关键机制的"诱导头"(Olsson 等,2022)之间有趣的相似性,并展示了该机制如何被理解为 Transformer 内部通过梯度下降进行上下文学习的一种特例。实验复现代码参见 https://github.com/google-research/self-organising-systems/tree/master/transformers_learn_icl_by_gd。