Understanding the learning process and the computation embedded in transformers is becoming a central goal for the development of interpretable AI. In the present study, we introduce a hierarchical filtering procedure for generative models of sequences on trees, which allows us to hand-tune the range of positional correlations in the data. Leveraging this controlled setting, we provide evidence that vanilla encoder-only transformers can approximate the exact inference algorithm when trained on root classification and masked language modeling tasks, and we study how this computation is discovered and implemented. We find that correlations at larger distances, corresponding to higher levels of the hierarchy, are incorporated by the network sequentially during training. Moreover, by comparing attention maps from models trained with varying degrees of filtering and by probing the different encoder levels, we find clear evidence that correlations are reconstructed on successive length scales corresponding to the levels of the hierarchy, which we relate to a plausible implementation of the exact inference algorithm within the same architecture.
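To make the filtering procedure concrete, the following is a minimal sketch of a tree-structured generative model with a tunable correlation range. All specifics here are illustrative assumptions rather than the paper's actual construction: the binary tree, the alphabet size `q`, the shared random production rules, and the `k_filter` parameter (a cutoff counted from the leaves, above which children are drawn independently of their parent) are hypothetical choices standing in for details not given in this section.

```python
import numpy as np

rng = np.random.default_rng(0)
q, depth = 4, 5  # alphabet size; leaf sequence has 2**depth tokens

# Random production rules p(left, right | parent), shared across the tree
# (hypothetical parameterization: normalized exponentiated Gaussian logits).
logits = rng.normal(size=(q, q, q))
rules = np.exp(logits) / np.exp(logits).sum(axis=(1, 2), keepdims=True)

def sample_leaves(parent, level, k_filter):
    """Recursively sample the 2**level leaves below `parent`.

    Levels deeper than `k_filter` (counted from the leaves) keep the
    parent-child dependence; above it, children are drawn i.i.d., so
    correlations in the leaf sequence cannot extend beyond 2**k_filter
    tokens. This mimics the hand-tuned correlation range of the abstract.
    """
    if level == 0:
        return [int(parent)]
    if level > k_filter:
        # Filtered level: ignore the parent, draw children independently.
        left, right = rng.integers(q), rng.integers(q)
    else:
        # Unfiltered level: draw the child pair from the production rules.
        flat = rng.choice(q * q, p=rules[parent].ravel())
        left, right = divmod(flat, q)
    return (sample_leaves(left, level - 1, k_filter)
            + sample_leaves(right, level - 1, k_filter))

root = rng.integers(q)
seq_full = sample_leaves(root, depth, k_filter=depth)  # full hierarchy kept
seq_cut = sample_leaves(root, depth, k_filter=2)       # correlations limited to 4-token blocks
print(seq_full)
print(seq_cut)
```

Because the data-generating graph is a tree, the exact inference the abstract refers to (recovering the root, or a masked leaf, from the observed sequence) can in principle be carried out by belief propagation, which is exact on tree-structured graphical models; the trained transformer is compared against that baseline.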