Deploying long-context large language models (LLMs) is essential but poses significant computational and memory challenges. Caching all Key and Value (KV) states across all attention heads consumes substantial memory. Existing KV cache pruning methods either damage the long-context capabilities of LLMs or offer only limited efficiency improvements. In this paper, we identify that only a fraction of attention heads, a.k.a. Retrieval Heads, are critical for processing long contexts and require full attention across all tokens. In contrast, all other heads, which primarily focus on recent tokens and attention sinks (referred to as Streaming Heads), do not require full attention. Based on this insight, we introduce DuoAttention, a framework that applies a full KV cache only to retrieval heads while using a lightweight, constant-length KV cache for streaming heads, reducing both decoding and pre-filling memory and latency for LLMs without compromising their long-context abilities. DuoAttention uses a lightweight, optimization-based algorithm with synthetic data to identify retrieval heads accurately. Our method significantly reduces long-context inference memory by up to 2.55x for MHA and 1.67x for GQA models, while speeding up decoding by up to 2.18x and 1.50x and accelerating pre-filling by up to 1.73x and 1.63x for MHA and GQA models, respectively, with minimal accuracy loss compared to full attention. Notably, combined with quantization, DuoAttention enables Llama-3-8B decoding with a 3.3 million token context length on a single A100 GPU. Code is provided at https://github.com/mit-han-lab/duo-attention.
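The streaming-head cache policy described above (retain only the initial attention-sink tokens plus a recent-token window, giving a constant-length KV cache) can be sketched as follows. This is a minimal illustration, not the paper's implementation; the function name and the default window sizes are assumptions made for clarity.

```python
def prune_streaming_kv(keys, values, sink_size=4, recent_size=256):
    """Evict KV entries for a streaming head, keeping only attention-sink
    tokens (the first `sink_size` positions) and the most recent
    `recent_size` positions. `sink_size` and `recent_size` are
    illustrative defaults, not the paper's exact configuration.

    keys/values: per-token cache entries (e.g., lists of vectors),
    one entry per token position.
    """
    seq_len = len(keys)
    # Nothing to evict while the sequence still fits in the budget.
    if seq_len <= sink_size + recent_size:
        return keys, values
    # Indices to keep: initial sink tokens + recent window.
    kept = list(range(sink_size)) + list(range(seq_len - recent_size, seq_len))
    return [keys[i] for i in kept], [values[i] for i in kept]
```

A retrieval head, by contrast, would skip this pruning step entirely and keep all `seq_len` entries, which is why DuoAttention's memory savings grow with the fraction of heads classified as streaming.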