Large language models (LLMs) with long context windows have gained significant attention. However, the key-value (KV) cache, stored to avoid re-computation, becomes a bottleneck. Various dynamic sparse or TopK-based attention approximation methods have been proposed, leveraging the common insight that attention is sparse. In this paper, we first show that TopK attention itself suffers from quality degradation on certain downstream tasks because attention is not always as sparse as expected. Rather than selecting the keys and values with the highest attention scores, sampling with theoretical guarantees provides a better estimate of the attention output. To make sampling-based approximation practical in LLM generation, we propose MagicPIG, a heterogeneous system based on Locality Sensitive Hashing (LSH). MagicPIG significantly reduces the workload of attention computation while preserving high accuracy across diverse tasks. MagicPIG stores the LSH hash tables and runs the attention computation on the CPU, which allows it to serve longer contexts and larger batch sizes with high approximation accuracy. MagicPIG improves decoding throughput by $1.9\sim3.9\times$ across various GPU hardware and achieves 110 ms decoding latency on a single RTX 4090 for the Llama-3.1-8B-Instruct model with a context of 96k tokens. The code is available at \url{https://github.com/Infini-AI-Lab/MagicPIG}.
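To make the LSH idea concrete, the following is a minimal, self-contained sketch (not MagicPIG's actual implementation) of how random-hyperplane hashing (SimHash) can retrieve a subset of cached keys that likely have high inner product with the query, after which attention is computed only over that subset. The single hash table, the bit width `k`, and the uniform weighting are simplifying assumptions; the real system uses multiple tables, collision-probability corrections for an unbiased estimate, and runs this stage on the CPU.

```python
import numpy as np

rng = np.random.default_rng(0)

d, n, k = 64, 1024, 8               # head dim, number of cached keys, hash bits
keys = rng.standard_normal((n, d))  # stand-ins for the KV cache
values = rng.standard_normal((n, d))
query = rng.standard_normal(d)

# SimHash: k random hyperplanes; the sign pattern of the projections is the bucket id.
planes = rng.standard_normal((k, d))

def bucket(x):
    # Pack the sign bits of the k projections into one integer code.
    return int("".join("1" if p @ x > 0 else "0" for p in planes), 2)

# Build the hash table over cached keys (done once per table).
table = {}
for i, key in enumerate(keys):
    table.setdefault(bucket(key), []).append(i)

# At decode time, look up the keys that collide with the query.
idx = np.array(table.get(bucket(query), []), dtype=int)

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

if idx.size:
    # Attention restricted to the sampled subset of the cache.
    w = softmax(keys[idx] @ query / np.sqrt(d))
    out = w @ values[idx]
else:
    out = np.zeros(d)               # fallback if no key collides with the query
```

Because nearby vectors agree on most hyperplane signs, keys landing in the query's bucket are biased toward large dot products, which is what makes sampling via collisions cheaper than scoring the full cache.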