Long-context inference in large language models is bottlenecked by Key--Value (KV) cache loading during the decoding stage, where the sequential nature of generation requires repeatedly transferring the KV cache from off-chip High-Bandwidth Memory (HBM) to on-chip Static Random-Access Memory (SRAM) at each step. While Multi-Head Latent Attention (MLA) significantly reduces the total KV cache size, it suffers from a sharding bottleneck during distributed decoding via Tensor Parallelism (TP). Since its single latent head cannot be partitioned, each device is forced to redundantly load the complete KV cache for every token, incurring excessive memory traffic and diminishing the benefits of TP, such as weight sharding. In this work, we propose Multi-Head Low-Rank Attention (MLRA), which produces partitionable latent states and thereby enables efficient 4-way TP decoding. Extensive experiments show that MLRA achieves state-of-the-art perplexity and downstream task performance, while also delivering a 2.8$\times$ decoding speedup over MLA. Code is available at https://github.com/SongtaoLiu0823/MLRA. Pretrained weights, along with the training and evaluation data, are available at https://huggingface.co/Soughing/MLRA.
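The sharding bottleneck described above can be illustrated with a toy memory-traffic model (this sketch is our own illustration, not code or parameter values from the paper; the latent dimension and TP degree below are hypothetical): when the cached latent state cannot be partitioned, every device in the TP group loads the full cache each decode step, whereas a partitionable cache lets each device load only its shard.

```python
# Illustrative per-step, per-device KV-cache load under tensor parallelism.
# Not the paper's implementation; dimensions below are hypothetical.

def kv_bytes_per_device(seq_len: int, dim_per_token: int,
                        tp_degree: int, partitionable: bool) -> int:
    """Bytes of cached state one device must read from HBM per decode step.

    dim_per_token: cached elements per token (e.g., the latent dimension
    for a compressed KV cache). If the cache is partitionable across the
    TP group, each device reads only its 1/tp_degree shard; otherwise
    every device must read the entire cache redundantly.
    """
    total_bytes = seq_len * dim_per_token * 2  # bf16/fp16: 2 bytes/element
    return total_bytes // tp_degree if partitionable else total_bytes

seq_len, tp, latent_dim = 128_000, 4, 512  # hypothetical configuration

# MLA-style cache: single latent head, cannot be sharded across devices.
mla_load = kv_bytes_per_device(seq_len, latent_dim, tp, partitionable=False)
# MLRA-style cache (assumption): same total size, but split into latent
# heads that can be partitioned across the 4-way TP group.
mlra_load = kv_bytes_per_device(seq_len, latent_dim, tp, partitionable=True)

print(mla_load // mlra_load)  # → 4: each device reads tp× less cache
```

Under this simple model, per-device memory traffic drops by the TP degree (4 here); the paper's measured end-to-end speedup (2.8$\times$ over MLA) is smaller, as real decoding also includes compute and communication costs.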