Recurrent neural networks (RNNs), such as linear attention and state-space models, have gained popularity due to their constant per-token complexity when processing long contexts. However, these recurrent models struggle with tasks that require accurate recall of information from long contexts, because all contextual information is compressed into a fixed-size recurrent state. Previous studies have shown that recall ability is positively correlated with the recurrent state size, yet directly training RNNs with large recurrent states incurs high training costs. In this paper, we introduce StateX, a post-training framework that efficiently expands the states of pre-trained RNNs. For two popular classes of RNNs, linear attention and state-space models, we design post-training architectural modifications that scale up the state size with no or negligible increase in model parameters. Experiments on models with up to 1.3B parameters demonstrate that StateX efficiently enhances the recall and in-context learning performance of RNNs without incurring high post-training costs or compromising other capabilities.
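To make the fixed-size-state bottleneck concrete, the following is a minimal sketch of one common textbook form of (unnormalized) linear attention, written in plain Python. It is an illustration under stated assumptions, not the StateX method or any model evaluated in the paper: the function names (`linear_attention_step`, `run`, `state_size`) and the toy dimensions are hypothetical. The state is a `d_k x d_v` matrix updated as `S <- S + k v^T` and read out as `o = S^T q`, so its size does not grow with the number of tokens processed.

```python
def linear_attention_step(state, k, v, q):
    """One recurrent step of unnormalized linear attention (illustrative).

    state: d_k x d_v nested list, updated in place as S <- S + k v^T.
    k, q:  lists of length d_k; v: list of length d_v.
    Returns the readout o = S^T q as a list of length d_v.
    """
    d_k, d_v = len(k), len(v)
    for i in range(d_k):
        for j in range(d_v):
            state[i][j] += k[i] * v[j]
    return [sum(state[i][j] * q[i] for i in range(d_k)) for j in range(d_v)]


def run(tokens, d_k, d_v):
    """Process a sequence of (k, v, q) triples with a single fixed-size state."""
    state = [[0.0] * d_v for _ in range(d_k)]
    outputs = [linear_attention_step(state, k, v, q) for k, v, q in tokens]
    return state, outputs


def state_size(state):
    """Number of scalars in the recurrent state (d_k * d_v)."""
    return sum(len(row) for row in state)


if __name__ == "__main__":
    # Two tokens with orthogonal keys; both queries ask for the first key.
    tokens = [
        ([1.0, 0.0], [1.0, 2.0], [1.0, 0.0]),
        ([0.0, 1.0], [3.0, 4.0], [1.0, 0.0]),
    ]
    state, outputs = run(tokens, d_k=2, d_v=2)
    print(outputs)            # readout for each step
    print(state_size(state))  # stays d_k * d_v regardless of sequence length
```

In this toy setting, a `d_k`-dimensional key space admits at most `d_k` mutually orthogonal keys, so once more associations are written than the state can separate, readouts begin to mix values. This is one simple way to see why recall is tied to state size.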