Transformer-based language models are trained on large datasets to predict the next token given an input sequence. Despite this simple training objective, they have led to revolutionary advances in natural language processing. Underlying this success is the self-attention mechanism. In this work, we ask: $\textit{What}$ $\textit{does}$ $\textit{a}$ $\textit{single}$ $\textit{self-attention}$ $\textit{layer}$ $\textit{learn}$ $\textit{from}$ $\textit{next-token}$ $\textit{prediction?}$ We show that training self-attention with gradient descent learns an automaton which generates the next token in two distinct steps: $\textbf{(1)}$ $\textbf{Hard}$ $\textbf{retrieval:}$ Given input sequence, self-attention precisely selects the $\textit{high-priority}$ $\textit{input}$ $\textit{tokens}$ associated with the last input token. $\textbf{(2)}$ $\textbf{Soft}$ $\textbf{composition:}$ It then creates a convex combination of the high-priority tokens from which the next token can be sampled. Under suitable conditions, we rigorously characterize these mechanics through a directed graph over tokens extracted from the training data. We prove that gradient descent implicitly discovers the strongly-connected components (SCC) of this graph and self-attention learns to retrieve the tokens that belong to the highest-priority SCC available in the context window. Our theory relies on decomposing the model weights into a directional component and a finite component that correspond to hard retrieval and soft composition steps respectively. This also formalizes a related implicit bias formula conjectured in [Tarzanagh et al. 2023]. We hope that these findings shed light on how self-attention processes sequential data and pave the path toward demystifying more complex architectures.
翻译:基于Transformer的语言模型通过大规模数据集训练,通过给定输入序列预测下一个词元。尽管训练目标简单,却推动了自然语言处理的革命性进展。其成功的关键在于自注意力机制。本研究探讨:$\textit{单个自注意力层从下一词元预测中学到了什么?}$ 我们证明,通过梯度下降训练的自注意力层会学习一种自动机,通过两个不同步骤生成下一个词元:$\textbf{(1)}$ $\textbf{硬检索:}$ 给定输入序列,自注意力层精确选择与最后输入词元相关联的$\textit{高优先级输入词元}$。$\textbf{(2)}$ $\textbf{软组合:}$ 随后,它创建高优先级词元的凸组合,从中可采样得到下一个词元。在适当条件下,我们通过从训练数据中提取的词元有向图严格刻画这些机制。我们证明梯度下降隐式发现该图的强连通分量(SCC),自注意力层学会检索上下文窗口中最高优先级SCC中的词元。我们的理论将模型权重分解为对应硬检索和软组合步骤的方向性分量与有限分量。这形式化了[Tarzanagh等人2023]中推测的隐式偏差公式。我们希望这些发现能阐明自注意力处理序列数据的机制,并为解密更复杂架构铺平道路。