Transformers with linear attention allow for efficient parallel training but can also be formulated as an RNN with 2D (matrix-valued) hidden states, thus enjoying inference complexity that is linear in the output length. Recent works such as RetNet (Sun et al., 2023) and TransNormerLLM (Qin et al., 2023a) observe that adding a global decay term to the additive RNN update rule greatly improves performance, sometimes outperforming standard Transformers with softmax attention when trained at scale. In this work we show that adding a data-dependent gating mechanism further improves performance. We derive a parallel form of this gated linear attention layer that enables efficient training. However, a straightforward, numerically stable implementation of this parallel form requires generalized matrix multiplications in log-space, and thus cannot take advantage of tensor cores on modern GPUs, which are optimized for standard matrix multiplications. We develop a hardware-efficient version of the parallel form that can still make use of tensor cores through block-parallel computations over sequence chunks. Experiments on moderate-scale language modeling (340M-parameter models trained on 15B tokens, 1.3B-parameter models trained on 100B tokens) show that gated linear attention (GLA) Transformers perform competitively against a strong LLaMA-architecture Transformer baseline (Touvron et al., 2023) as well as Mamba (Gu & Dao, 2023), a recently introduced state-space model with a data-dependent state transition mechanism. In terms of training speed, our Triton-based implementation performs comparably to CUDA-optimized FlashAttention-2 (Dao, 2023) at the standard training length of 2048, while outperforming FlashAttention-2 when training on sequences longer than 4096.
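To make the recurrent view concrete, the following is a minimal NumPy sketch, not the implementation described in this work. It assumes an unnormalized linear attention layer whose matrix-valued state is updated as S_t = diag(alpha_t) S_{t-1} + k_t^T v_t with output o_t = q_t S_t. The per-key-dimension gate `alpha`, the function names, and the tensor shapes are illustrative assumptions, since the abstract does not specify the gate's parameterization. Setting the gate to a constant `gamma` gives a global decay of the kind the abstract attributes to RetNet and TransNormerLLM, and in that case the recurrence should agree with a masked parallel form.

```python
import numpy as np


def recurrent_gated_linear_attention(q, k, v, alpha):
    """Recurrent (RNN) form with a matrix-valued hidden state.

    q, k:  (T, d_k) queries and keys
    v:     (T, d_v) values
    alpha: (T, d_k) gates in (0, 1]; this per-key-dimension form is an
           assumption made for illustration only.
    """
    T, d_k = k.shape
    d_v = v.shape[1]
    S = np.zeros((d_k, d_v))           # 2D hidden state
    out = np.empty((T, d_v))
    for t in range(T):
        # Decay the previous state row-wise, then add the outer product k_t^T v_t.
        S = alpha[t][:, None] * S + np.outer(k[t], v[t])
        out[t] = q[t] @ S
    return out


def parallel_decay_attention(q, k, v, gamma):
    """Parallel form for a constant (data-independent) decay gamma.

    Computes ((Q K^T) * D) V, where D[i, j] = gamma**(i - j) for i >= j
    and 0 otherwise (a causal decay mask).
    """
    T = q.shape[0]
    idx = np.arange(T)
    diff = idx[:, None] - idx[None, :]
    D = np.where(diff >= 0, gamma ** np.maximum(diff, 0), 0.0)
    return ((q @ k.T) * D) @ v


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    T, d_k, d_v = 8, 4, 3
    q = rng.standard_normal((T, d_k))
    k = rng.standard_normal((T, d_k))
    v = rng.standard_normal((T, d_v))

    gamma = 0.9
    const_gate = np.full((T, d_k), gamma)
    rec = recurrent_gated_linear_attention(q, k, v, const_gate)
    par = parallel_decay_attention(q, k, v, gamma)
    print("max abs difference:", np.abs(rec - par).max())
```

The recurrent loop keeps only a d_k x d_v state per step, which is where the linear inference cost comes from, while the parallel form is a pair of ordinary matrix multiplications. With a data-dependent gate the mask no longer depends only on i - j, so this simple parallel function does not apply; the sketch does not attempt the log-space or chunked formulations mentioned above.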