Recurrent large language models that compete with Transformers in language modeling perplexity are emerging at a rapid rate (e.g., Mamba, RWKV). Excitingly, these architectures use a constant amount of memory during inference. However, due to the limited memory, recurrent LMs cannot recall and use all the information in long contexts leading to brittle in-context learning (ICL) quality. A key challenge for efficient LMs is selecting what information to store versus discard. In this work, we observe the order in which information is shown to the LM impacts the selection difficulty. To formalize this, we show that the hardness of information recall reduces to the hardness of a problem called set disjointness (SD), a quintessential problem in communication complexity that requires a streaming algorithm (e.g., recurrent model) to decide whether inputted sets are disjoint. We empirically and theoretically show that the recurrent memory required to solve SD changes with set order, i.e., whether the smaller set appears first in-context. Our analysis suggests, to mitigate the reliance on data order, we can put information in the right order in-context or process prompts non-causally. Towards that end, we propose: (1) JRT-Prompt, where context gets repeated multiple times in the prompt, effectively showing the model all data orders. This gives $11.0 \pm 1.3$ points of improvement, averaged across $16$ recurrent LMs and the $6$ ICL tasks, with $11.9\times$ higher throughput than FlashAttention-2 for generation prefill (length $32$k, batch size $16$, NVidia H100). We then propose (2) JRT-RNN, which uses non-causal prefix-linear-attention to process prompts and provides $99\%$ of Transformer quality at $360$M params., $30$B tokens and $96\%$ at $1.3$B params., $50$B tokens on average across the tasks, with $19.2\times$ higher throughput for prefill than FA2.
翻译:与Transformer在语言建模困惑度上竞争的循环大语言模型正快速涌现(如Mamba、RWKV)。令人振奋的是,这些架构在推理时仅需恒定内存。然而,由于内存有限,循环语言模型无法召回并利用长上下文中的所有信息,导致上下文学习质量不稳定。高效语言模型的一个关键挑战在于选择存储与丢弃哪些信息。本工作中,我们观察到信息呈现给语言模型的顺序会影响选择难度。为形式化此现象,我们证明信息召回的难度可归结为集合不相交问题——这是通信复杂度中的一个典型问题,要求流式算法(如循环模型)判断输入的集合是否不相交。我们通过实验与理论证明,解决集合不相交问题所需的循环内存随集合顺序变化,即较小集合是否在上下文中先出现。分析表明,为减轻对数据顺序的依赖,可在上下文中调整信息顺序或以非因果方式处理提示。基于此,我们提出:(1)JRT-Prompt方法,通过多次重复上下文内容,使模型接触所有数据顺序。该方法在16个循环语言模型和6个上下文学习任务上平均带来$11.0 \pm 1.3$分的提升,且在生成预填充阶段(长度32k,批量大小16,NVidia H100)吞吐量比FlashAttention-2高$11.9$倍。随后提出(2)JRT-RNN架构,采用非因果前缀线性注意力处理提示,在360M参数/30B token规模下平均达到Transformer 99%的性能,在1.3B参数/50B token规模下达到96%性能,其预填充吞吐量比FlashAttention-2高$19.2$倍。