A key advantage of Recurrent Neural Networks (RNNs) over Transformers is that their linear computational and space complexity enables faster training and inference on long sequences. However, RNNs are fundamentally unable to randomly access historical context, and simply integrating attention mechanisms can undermine their efficiency advantages. To overcome this limitation, we propose Hierarchical Sparse Attention (HSA), a novel attention mechanism that endows RNNs with long-range random-access flexibility while preserving their advantages in efficiency and length generalization. HSA divides the input into chunks, selects the top-$k$ chunks, and hierarchically aggregates information. Its core innovation is learning token-to-chunk relevance from fine-grained token-level information within each chunk, which improves the precision of chunk selection at both in-domain and out-of-domain context lengths. To make HSA efficient, we further introduce a hardware-aligned kernel design. Combining HSA with Mamba yields RAMba, which achieves perfect accuracy on passkey retrieval over 64-million-token contexts despite being pre-trained only on 4K-length contexts, delivers significant improvements on various downstream tasks, and maintains a nearly constant memory footprint. These results demonstrate RAMba's strong potential for long-context modeling.
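To make the chunk-selection idea concrete, the following is a minimal sketch of token-to-chunk relevance scoring followed by top-$k$ chunk selection and a two-level aggregation, for a single query token. The shapes, the softmax-weighted relevance pooling, the function name `hsa_sketch`, and the final gating step are illustrative assumptions, not the paper's exact formulation or kernel.

```python
# Hedged sketch: score chunks from token-level information, pick top-k,
# then aggregate hierarchically (within selected chunks, then across them).
import torch
import torch.nn.functional as F

def hsa_sketch(q, keys, values, chunk_size=64, top_k=4):
    """q: (d,) query for one token; keys/values: (T, d) cached history."""
    T, d = keys.shape
    n_chunks = T // chunk_size
    k_chunks = keys[: n_chunks * chunk_size].view(n_chunks, chunk_size, d)
    v_chunks = values[: n_chunks * chunk_size].view(n_chunks, chunk_size, d)

    # Token-level scores inside every chunk: (n_chunks, chunk_size)
    token_scores = torch.einsum("d,ncd->nc", q, k_chunks) / d ** 0.5

    # Chunk relevance aggregated from fine-grained token scores
    # (here: a softmax-weighted mean per chunk; an assumption).
    within = F.softmax(token_scores, dim=-1)
    chunk_relevance = (within * token_scores).sum(-1)          # (n_chunks,)

    # Select the top-k most relevant chunks.
    sel = torch.topk(chunk_relevance, k=min(top_k, n_chunks)).indices

    # Hierarchical aggregation: attend over tokens within each selected
    # chunk, then combine chunk outputs weighted by chunk relevance.
    attn = F.softmax(token_scores[sel], dim=-1)                # (top_k, chunk_size)
    chunk_out = torch.einsum("nc,ncd->nd", attn, v_chunks[sel])
    gate = F.softmax(chunk_relevance[sel], dim=-1)             # (top_k,)
    return torch.einsum("n,nd->d", gate, chunk_out)
```

Because the relevance score is built from per-token scores rather than a single pooled chunk summary, selection remains sensitive to individual tokens inside each chunk, which is the property the abstract credits for robust chunk selection at unseen context lengths.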