Efficiently modeling sequences with infinite context length has been a long-standing problem. Past works suffer from either the quadratic computation complexity or the limited extrapolation ability on length generalization. In this work, we present Samba, a simple hybrid architecture that layer-wise combines Mamba, a selective State Space Model (SSM), with Sliding Window Attention (SWA). Samba selectively compresses a given sequence into recurrent hidden states while still maintaining the ability to precisely recall memories with the attention mechanism. We scale Samba up to 3.8B parameters with 3.2T training tokens and show that Samba substantially outperforms the state-of-the-art models based on pure attention or SSMs on a wide range of benchmarks. When trained on 4K length sequences, Samba can be efficiently extrapolated to 256K context length with perfect memory recall and show improved token predictions up to 1M context length. As a linear-time sequence model, Samba enjoys a 3.73x higher throughput compared to Transformers with grouped-query attention when processing user prompts of 128K length, and 3.64x speedup when generating 64K tokens with unlimited streaming. A sample implementation of Samba is publicly available in https://github.com/microsoft/Samba.
翻译:高效建模具有无限上下文长度的序列一直是一个长期存在的问题。以往的研究要么受限于二次计算复杂度,要么在长度外推泛化能力上存在不足。本研究提出Samba,一种简单的混合架构,通过逐层结合选择性状态空间模型Mamba与滑动窗口注意力机制。Samba能够将给定序列选择性压缩为循环隐藏状态,同时仍能通过注意力机制精确回溯记忆信息。我们将Samba扩展至38亿参数规模并使用3.2万亿训练词元进行训练,结果表明Samba在广泛基准测试中显著优于基于纯注意力或状态空间模型的现有最优模型。当使用4K长度序列训练时,Samba可高效外推至256K上下文长度并实现完美记忆回溯,在长达1M的上下文范围内均展现出持续提升的词元预测能力。作为线性时间序列模型,在处理128K长度用户提示时,Samba相比采用分组查询注意力的Transformer模型吞吐量提升3.73倍;在无限流式生成64K词元时实现3.64倍加速。Samba的示例实现已公开于https://github.com/microsoft/Samba。