The Transformer architecture has shown impressive performance in multiple research domains and has become the backbone of many neural network models. However, there is limited understanding of how it works. In particular, how the representation emerges from the gradient \emph{training dynamics} under a simple predictive loss remains a mystery. In this paper, for a 1-layer transformer with one self-attention layer plus one decoder layer, we analyze its SGD training dynamics on the task of next token prediction in a mathematically rigorous manner. We open the black box of the dynamic process by which the self-attention layer combines input tokens, and reveal the nature of the underlying inductive bias. More specifically, under the assumptions of (a) no positional encoding, (b) long input sequences, and (c) a decoder layer that learns faster than the self-attention layer, we prove that self-attention acts as a \emph{discriminative scanning algorithm}: starting from uniform attention, it gradually attends more to distinct key tokens for a specific next token to be predicted, and pays less attention to common key tokens that occur across different next tokens. Among distinct tokens, it progressively drops attention weights, following the order of low to high co-occurrence between the key and the query token in the training set. Interestingly, this procedure does not lead to winner-takes-all, but decelerates due to a \emph{phase transition} that is controllable by the learning rates of the two layers, leaving an (almost) fixed token combination. We verify these \textbf{\emph{scan and snap}} dynamics on synthetic and real-world data (WikiText).
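To make the setting concrete, the following is a minimal NumPy sketch of a 1-layer model of the kind described above: one attention step with no positional encoding, followed by a linear decoder, trained by SGD with separate learning rates for the two layers. The parameterization is an illustrative assumption rather than the paper's formulation: tokens are one-hot, the attention logits are stored in a hypothetical table `Z` indexed by (query token, key token), the decoder is a hypothetical matrix `W`, and the names `forward`, `sgd_step`, `lr_attn`, and `lr_dec` are invented for this example.

```python
import numpy as np

def softmax(v):
    v = v - v.max()                      # shift for numerical stability
    e = np.exp(v)
    return e / e.sum()

def forward(Z, W, context, query):
    """context: int array of key token ids; query: id of the last token."""
    scores = Z[query, context]           # one attention logit per position
    attn = softmax(scores)               # uniform whenever Z is all zeros
    c = np.zeros(W.shape[1])
    np.add.at(c, context, attn)          # attention-weighted bag of one-hot tokens
    probs = softmax(W @ c)               # decoder: distribution over next tokens
    return attn, c, probs

def sgd_step(Z, W, context, query, target, lr_attn, lr_dec):
    """One SGD step on the cross-entropy loss for a single sequence."""
    attn, c, probs = forward(Z, W, context, query)
    err = probs.copy()
    err[target] -= 1.0                   # d loss / d decoder logits
    grad_W = np.outer(err, c)
    g = W.T @ err                        # d loss / d c
    g_pos = g[context]                   # gradient seen by each position
    grad_scores = attn * (g_pos - attn @ g_pos)   # softmax backward pass
    grad_Z = np.zeros_like(Z)
    np.add.at(grad_Z, (query, context), grad_scores)
    loss = -np.log(probs[target])
    # Assumption (c): the decoder uses a larger learning rate than attention.
    return Z - lr_attn * grad_Z, W - lr_dec * grad_W, loss

# Toy usage with a vocabulary of 5 tokens (all values are illustrative).
M = 5
Z0 = np.zeros((M, M))
W0 = np.zeros((M, M))
context = np.array([0, 1, 1, 3])
query, target = 2, 4

attn0, _, _ = forward(Z0, W0, context, query)      # 0.25 at every position
Z1, W1, loss0 = sgd_step(Z0, W0, context, query, target,
                         lr_attn=0.01, lr_dec=1.0)

# Without positional encoding, reordering the context leaves the output unchanged.
_, _, probs_a = forward(Z1, W1, context, query)
_, _, probs_b = forward(Z1, W1, context[::-1], query)
```

In this sketch the attention starts uniform because `Z` is zero-initialized, and the first step leaves `Z` untouched: with `W` still at zero, no gradient reaches the attention logits. Only once the decoder has moved does the attention begin to change, which is one simple way to see why the relative learning rates of the two layers matter.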