Language models struggle to generalize beyond their pretraining context lengths, limiting long-horizon reasoning and retrieval. Continued pretraining on long-context data can help but is expensive due to the quadratic scaling of Attention. We observe that most tokens do not require (Global) Attention over the entire sequence and can rely on local context. Based on this, we propose L2A (Learning To Attend), a layer that enables conditional (token-wise) long-range memory access by deciding when to invoke Global Attention. We evaluate L2A on Qwen 2.5 and Qwen 3 models, extending their effective context length from 32K to 128K tokens. L2A matches the performance of standard long-context training to within 3% while skipping Global Attention for $\sim$80% of tokens, outperforming prior baselines. We also design custom Triton kernels to efficiently implement this token-wise conditional Attention on GPUs, achieving up to $\sim$2x improvements in training throughput and time-to-first-token over FlashAttention. Moreover, L2A enables post-training pruning of highly sparse Global Attention layers, reducing KV cache memory by up to 50% with negligible performance loss.
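As a minimal illustration of the token-wise decision described above (this is a sketch under our own simplifying assumptions, not the L2A implementation; the module, `window`, and `gate_threshold` names are hypothetical), the snippet below routes each token either to full causal attention or to a local sliding window based on a learned per-token gate. The actual layer skips the global computation in custom Triton kernels rather than masking it.

```python
# Minimal single-head sketch of token-wise conditional global attention.
# Illustrative only: a real implementation would skip, not mask, the global KV reads,
# and would need a differentiable gate (e.g. soft or straight-through) during training.
import torch
import torch.nn.functional as F

def sliding_window_mask(seq_len, window, device):
    # Causal mask restricted to the most recent `window` keys.
    idx = torch.arange(seq_len, device=device)
    rel = idx[None, :] - idx[:, None]           # key position minus query position
    return (rel <= 0) & (rel > -window)

def causal_mask(seq_len, device):
    idx = torch.arange(seq_len, device=device)
    return idx[None, :] <= idx[:, None]

class ConditionalAttentionSketch(torch.nn.Module):
    def __init__(self, dim, window=128, gate_threshold=0.5):
        super().__init__()
        self.qkv = torch.nn.Linear(dim, 3 * dim)
        self.gate = torch.nn.Linear(dim, 1)      # per-token decision: local vs. global
        self.window = window
        self.gate_threshold = gate_threshold

    def forward(self, x):
        B, T, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        scores = (q @ k.transpose(-2, -1)) * D ** -0.5      # (B, T, T)

        local = sliding_window_mask(T, self.window, x.device)   # (T, T)
        full = causal_mask(T, x.device)                          # (T, T)

        # Each token chooses its own mask: full causal context or local window only.
        use_global = torch.sigmoid(self.gate(x)) > self.gate_threshold   # (B, T, 1)
        mask = torch.where(use_global, full, local)                       # (B, T, T)

        scores = scores.masked_fill(~mask, float("-inf"))
        return F.softmax(scores, dim=-1) @ v
```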