Transformers with linear attention allow efficient parallel training but can also be formulated as an RNN with a 2D (matrix-valued) hidden state, which gives inference complexity that is linear in the output length. Recent works such as RetNet (Sun et al., 2023) and TransNormerLLM (Qin et al., 2023a) observe that adding a global decay term to the additive RNN update rule greatly improves performance, in some cases outperforming standard Transformers with softmax attention when trained at scale. In this work we show that adding a data-dependent gating mechanism further improves performance. We derive a parallel form of this gated linear attention layer that enables efficient training. However, a straightforward, numerically stable implementation of this parallel form requires generalized matrix multiplications in log space, and thus cannot take advantage of the tensor cores on modern GPUs, which are optimized for standard matrix multiplications. We develop a hardware-efficient version of the parallel form that can still use tensor cores through block-parallel computations over sequence chunks. Experiments on moderate-scale language modeling (340M-parameter models trained on 15B tokens and 1.3B-parameter models trained on 100B tokens) show that gated linear attention (GLA) Transformers perform competitively against a strong LLaMA-architecture Transformer baseline (Touvron et al., 2023) as well as against Mamba (Gu & Dao, 2023), a recently introduced state-space model with a data-dependent state transition mechanism. In terms of training speed, our Triton-based implementation performs comparably to the CUDA-optimized FlashAttention-2 (Dao, 2023) at the standard training length of 2048 tokens, and outperforms FlashAttention-2 when training on sequences longer than 4096 tokens.
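The relationship between the recurrent, parallel, and chunkwise views described above can be illustrated with a small pure-Python sketch. It is not the authors' Triton implementation. It assumes the gated update S_t = Diag(alpha_t) S_{t-1} + k_t^T v_t with read-out o_t = q_t S_t, and it uses random per-channel gates alpha_t in (0, 1) as a stand-in for gates computed from the input. The function names `recurrent_gla`, `parallel_gla`, `chunkwise_gla`, and the helpers are hypothetical. The parallel form keeps the cumulative decays in log space, which is why each (t, s) score needs its own per-channel decay and cannot be written as a plain matrix product. The chunkwise form carries the state across chunks and applies the parallel form only inside each chunk.

```python
import math
import random


def recurrent_gla(q, k, v, alpha):
    """Recurrent form (assumed): S_t = Diag(alpha_t) S_{t-1} + k_t^T v_t, o_t = q_t S_t."""
    dk, dv = len(k[0]), len(v[0])
    S = [[0.0] * dv for _ in range(dk)]  # matrix-valued (dk x dv) hidden state
    out = []
    for q_t, k_t, v_t, a_t in zip(q, k, v, alpha):
        for i in range(dk):
            for j in range(dv):
                # Decay row i by its own gate, then add the outer product k_t^T v_t
                S[i][j] = a_t[i] * S[i][j] + k_t[i] * v_t[j]
        out.append([sum(q_t[i] * S[i][j] for i in range(dk)) for j in range(dv)])
    return out


def log_cumsum(rows):
    """Running sum of log gates per channel: entry t holds sum_{j<=t} log alpha_j."""
    acc, result = [0.0] * len(rows[0]), []
    for a_t in rows:
        acc = [acc[i] + math.log(a_t[i]) for i in range(len(a_t))]
        result.append(acc)
    return result


def parallel_gla(q, k, v, alpha):
    """Parallel form: o_t = sum_{s<=t} (sum_i q_t[i] k_s[i] b_t[i] / b_s[i]) v_s,
    with the ratio b_t / b_s evaluated as exp(log b_t - log b_s) for stability."""
    dk, dv = len(k[0]), len(v[0])
    logb = log_cumsum(alpha)
    out = []
    for t in range(len(q)):
        o_t = [0.0] * dv
        for s in range(t + 1):
            # Each (t, s) score needs a per-channel decay, so this is not a plain matmul
            score = sum(q[t][i] * k[s][i] * math.exp(logb[t][i] - logb[s][i])
                        for i in range(dk))
            for j in range(dv):
                o_t[j] += score * v[s][j]
        out.append(o_t)
    return out


def chunkwise_gla(q, k, v, alpha, chunk=2):
    """Chunkwise form: parallel form inside each chunk, state S carried between chunks."""
    dk, dv = len(k[0]), len(v[0])
    S = [[0.0] * dv for _ in range(dk)]
    out = []
    for start in range(0, len(q), chunk):
        idx = range(start, min(start + chunk, len(q)))
        logc = log_cumsum([alpha[t] for t in idx])  # decay measured from chunk start
        for a, t in enumerate(idx):
            # Inter-chunk term: (q_t * decay since chunk start) applied to the carried state
            qd = [q[t][i] * math.exp(logc[a][i]) for i in range(dk)]
            o_t = [sum(qd[i] * S[i][j] for i in range(dk)) for j in range(dv)]
            # Intra-chunk term: causal parallel form restricted to this chunk
            for b, s in enumerate(idx[:a + 1]):
                score = sum(q[t][i] * k[s][i] * math.exp(logc[a][i] - logc[b][i])
                            for i in range(dk))
                for j in range(dv):
                    o_t[j] += score * v[s][j]
            out.append(o_t)
        # State update: decay the old state over the whole chunk, add in-chunk updates
        last = len(idx) - 1
        for i in range(dk):
            for j in range(dv):
                S[i][j] = math.exp(logc[last][i]) * S[i][j] + sum(
                    k[s][i] * math.exp(logc[last][i] - logc[b][i]) * v[s][j]
                    for b, s in enumerate(idx))
    return out


def max_abs_diff(x, y):
    return max(abs(a - b) for row_x, row_y in zip(x, y) for a, b in zip(row_x, row_y))


def random_matrix(rows, cols, lo=-1.0, hi=1.0):
    return [[random.uniform(lo, hi) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
L, dk, dv = 5, 3, 2
q, k, v = random_matrix(L, dk), random_matrix(L, dk), random_matrix(L, dv)
alpha = random_matrix(L, dk, 0.5, 0.99)  # stand-in for data-dependent gates in (0, 1)
ref = recurrent_gla(q, k, v, alpha)
print(max_abs_diff(ref, parallel_gla(q, k, v, alpha)))
print(max_abs_diff(ref, chunkwise_gla(q, k, v, alpha, chunk=2)))
```

Both printed differences should be on the order of floating-point rounding error. In this sketch everything is written as scalar loops for clarity, so it says nothing about speed; the point is only that the per-chunk read-out of the carried state and the chunk-level state update are the parts a GPU kernel could express as dense matrix multiplications, while the log-space decays stay confined to the small intra-chunk blocks.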