Large language models (LLMs) with long context windows have gained significant attention. However, the KV cache, which is stored to avoid re-computation, becomes a bottleneck as the context grows. Various dynamic sparse or TopK-based attention approximation methods have been proposed to exploit the common insight that attention is sparse. In this paper, we first show that TopK attention itself suffers from quality degradation on certain downstream tasks because attention is not always as sparse as expected. Rather than selecting the keys and values with the highest attention scores, sampling with theoretical guarantees can provide a better estimate of the attention output. To make sampling-based approximation practical for LLM generation, we propose MagicPIG, a heterogeneous system based on Locality Sensitive Hashing (LSH). MagicPIG significantly reduces the workload of attention computation while preserving high accuracy across diverse tasks. MagicPIG stores the LSH hash tables and runs the attention computation on the CPU, which allows it to serve longer contexts and larger batch sizes with high approximation accuracy. MagicPIG improves decoding throughput by up to $5\times$ across various GPU hardware and achieves 54ms decoding latency on a single RTX 4090 for the Llama-3.1-8B-Instruct model with a context of 96k tokens. The code is available at https://github.com/Infini-AI-Lab/MagicPIG.
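To make the sampling idea concrete, below is a minimal sketch of LSH-based sampled attention in the spirit the abstract describes: SimHash-style sign-bit hashing selects candidate keys, and each sampled key is reweighted by its known collision probability to form a self-normalized importance estimate of the attention output. All function names and parameters here are illustrative assumptions, not MagicPIG's actual implementation or API.

```python
import numpy as np

rng = np.random.default_rng(0)

def simhash_codes(x, planes):
    # One sign bit per random hyperplane (SimHash).
    return (x @ planes.T) > 0  # shape: (..., K)

def collision_prob(q, k, K):
    # For SimHash, P[one bit agrees] = 1 - theta/pi, where theta is
    # the angle between q and k; K independent bits multiply.
    cos = (q @ k) / (np.linalg.norm(q) * np.linalg.norm(k) + 1e-12)
    theta = np.arccos(np.clip(cos, -1.0, 1.0))
    return (1.0 - theta / np.pi) ** K

def lsh_sampled_attention(q, keys, values, K=8, L=16):
    # Collect keys that collide with the query in any of L hash tables.
    d = q.shape[0]
    sampled = set()
    for _ in range(L):
        planes = rng.standard_normal((K, d))
        qc = simhash_codes(q, planes)
        kc = simhash_codes(keys, planes)
        sampled.update(np.where((kc == qc).all(axis=1))[0].tolist())
    if not sampled:
        return np.zeros_like(values[0])
    idx = np.fromiter(sampled, dtype=int)
    # Probability that key i collides in at least one of the L tables.
    p1 = np.array([collision_prob(q, keys[i], K) for i in idx])
    p = 1.0 - (1.0 - p1) ** L
    # Self-normalized importance weighting: divide each sampled softmax
    # weight by its inclusion probability, then renormalize.
    logits = keys[idx] @ q / np.sqrt(d)
    w = np.exp(logits - logits.max()) / np.maximum(p, 1e-12)
    w /= w.sum()
    return w @ values[idx]

# Usage: compare against exact attention on random data.
d, n = 64, 4096
q = rng.standard_normal(d)
keys = rng.standard_normal((n, d))
values = rng.standard_normal((n, d))
logits = keys @ q / np.sqrt(d)
exact_w = np.exp(logits - logits.max())
exact = (exact_w / exact_w.sum()) @ values
approx = lsh_sampled_attention(q, keys, values)
print("relative error:", np.linalg.norm(approx - exact) / np.linalg.norm(exact))
```

The reweighting step is what distinguishes sampling from TopK selection: keys with moderate attention scores can still be drawn and, because they are divided by a smaller inclusion probability, contribute with corrected weight, which is how such estimators obtain theoretical guarantees that TopK truncation lacks.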