Efficient state space models (SSMs), such as linear recurrent neural networks and linear attention variants, offer computational advantages over Transformers but struggle with tasks requiring long-range in-context retrieval, such as text copying, associative recall, and question answering over long contexts. Previous efforts to address these challenges have focused on architectural modifications, often reintroducing computational inefficiencies. In this paper, we propose a novel training procedure, Birdie, that significantly enhances the in-context retrieval capabilities of SSMs without altering their architecture. Our approach combines bidirectional input processing with dynamic mixtures of specialized pre-training objectives, optimized via reinforcement learning. We also introduce a new bidirectional SSM architecture that seamlessly transitions from bidirectional context processing to causal generation. Experimental evaluations demonstrate that Birdie markedly improves performance on retrieval-intensive tasks such as multi-number phone book lookup, question answering over long paragraphs, and infilling, narrowing the performance gap with Transformers while retaining computational efficiency. Our findings highlight the importance of training procedures in leveraging the fixed-state capacity of SSMs and offer a new direction for advancing their capabilities. All code and pre-trained models are available at https://www.github.com/samblouir/birdie, with support for JAX and PyTorch.