Linear attention offers a computationally efficient yet expressive alternative to softmax attention. However, recent empirical results indicate that the state of trained linear attention models often exhibits a low-rank structure, suggesting that these models underexploit their capacity in practice. To illuminate this phenomenon, we provide a theoretical analysis of the role of rank in linear attention, revealing that low effective rank can increase retrieval error by amplifying query noise. Motivated by these theoretical insights, we conjecture that low-rank states can be substantially compressed post-training with only minimal performance degradation, yielding faster and more memory-efficient models. To this end, we propose a novel hardware-aware approach that structurally prunes key and query matrices, reducing the state size while retaining compatibility with existing CUDA kernels. We adapt several existing pruning strategies to fit our framework and, building on our theoretical analysis, propose a novel structured pruning method based on a rank-revealing QR decomposition. Our empirical evaluation, spanning models of varying sizes and a range of downstream tasks, demonstrates the effectiveness of our state reduction framework. We highlight that our framework enables the removal of 50% of the query and key channels with only a marginal increase in perplexity. The code for this project can be found at https://github.com/camail-official/LinearAttentionPruning.
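To make the idea concrete, the following is a minimal sketch, not the paper's exact algorithm, of how a column-pivoted (rank-revealing) QR factorization can rank query/key channels by importance and structurally prune a linear attention state. The toy data, shapes, and pruning criterion below are illustrative assumptions.

```python
# Illustrative sketch: structured channel pruning of a linear attention state
# via rank-revealing QR. Setup and criterion are hypothetical, not the paper's.
import numpy as np
from scipy.linalg import qr

rng = np.random.default_rng(0)
T, d_k, d_v = 256, 64, 64

# Keys whose channels carry widely varying energy: the resulting state K^T V
# has low effective rank because many channels contribute almost nothing.
scales = rng.permutation(np.logspace(0, -3, d_k))
K = rng.standard_normal((T, d_k)) * scales
V = rng.standard_normal((T, d_v))
S = K.T @ V                       # linear attention state, shape (d_k, d_v)

# Column-pivoted QR on K orders channels by how much they contribute.
_, _, piv = qr(K, pivoting=True)
keep = np.sort(piv[: d_k // 2])   # keep the 50% most informative channels

# Structured pruning: drop the same channels from queries and keys, shrinking
# the state while keeping dense shapes compatible with standard kernels.
S_pruned = K[:, keep].T @ V       # shape (d_k // 2, d_v)

q = K[0]                          # query a stored key to test retrieval
out_full = q @ S
out_pruned = q[keep] @ S_pruned
rel_err = np.linalg.norm(out_full - out_pruned) / np.linalg.norm(out_full)
print(f"pruned state shape: {S_pruned.shape}, relative retrieval error: {rel_err:.3f}")
```

In this synthetic setting the pruned state retrieves stored values with small relative error, since the dropped channels carry little energy; in a trained model the channel-importance structure must instead be discovered, which is the role the QR-based criterion plays.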