The quadratic time and memory complexity inherent to self-attention mechanisms, with respect to sequence length, presents a critical computational bottleneck in the training and deployment of large-scale Transformer-based language models. Recent theoretical results indicate the intractability of sub-quadratic softmax attention approximation under reasonable complexity assumptions. This paper addresses this challenge by first demonstrating that polynomial attention with high degree can effectively replace softmax without sacrificing model quality. Next, we develop polynomial sketching techniques from numerical linear algebra to achieve linear-time polynomial attention with approximation guarantees. Crucially, our approach achieves this speedup without requiring the sparsification of attention matrices. We also present a block-based algorithm to apply causal masking efficiently. Combining these techniques, we provide \emph{PolySketchFormer}, a practical linear-time Transformer architecture for language modeling that offers provable guarantees. We validate PolySketchFormer empirically by training language models capable of handling long contexts. These experiments utilize both synthetic and real-world datasets (PG19, Wikipedia and C4) on Google Cloud TPUs. For context lengths of 32k and GPT-2 style models, our model achieves a 2.5-4x speedup in training compared to FlashAttention, with no observed degradation in quality across our experiments.
翻译:摘要:自注意力机制相对于序列长度的二次时间与内存复杂度,成为大规模基于Transformer的语言模型训练与部署中的关键计算瓶颈。近期理论结果表明,在合理的复杂度假设下,次二次softmax注意力近似具有难以处理性。本文通过首先证明高次多项式注意力能在不牺牲模型质量的前提下有效替代softmax来应对这一挑战。接着,我们利用数值线性代数中的多项式草图技术,实现了具有近似保证的线性时间多项式注意力。关键在于,我们的方法无需稀疏化注意力矩阵即可实现加速。我们还提出一种基于块的算法,以高效应用因果掩码。结合这些技术,我们提供了PolySketchFormer,一种面向语言建模、具备可证明保证的实用线性时间Transformer架构。我们通过训练能够处理长上下文的语言模型,在Google Cloud TPU上使用合成及真实数据集(PG19、Wikipedia和C4)对PolySketchFormer进行实验验证。对于32k上下文长度和GPT-2风格模型,与FlashAttention相比,我们的模型在训练中实现了2.5-4倍的加速,且实验中未观察到质量下降。