We present an approximate attention mechanism named HyperAttention to address the computational challenges posed by the growing complexity of long contexts used in Large Language Models (LLMs). Recent work suggests that in the worst-case scenario, quadratic time is necessary unless the entries of the attention matrix are bounded or the matrix has low stable rank. We introduce two parameters which measure: (1) the max column norm in the normalized attention matrix, and (2) the ratio of row norms in the unnormalized attention matrix after detecting and removing large entries. We use these fine-grained parameters to capture the hardness of the problem. Despite previous lower bounds, we are able to achieve a linear time sampling algorithm even when the matrix has unbounded entries or a large stable rank, provided the above parameters are small. HyperAttention features a modular design that easily accommodates integration of other fast low-level implementations, particularly FlashAttention. Empirically, employing Locality Sensitive Hashing (LSH) to identify large entries, HyperAttention outperforms existing methods, giving significant speed improvements compared to state-of-the-art solutions like FlashAttention. We validate the empirical performance of HyperAttention on a variety of different long-context length datasets. For example, HyperAttention makes the inference time of ChatGLM2 50\% faster on 32k context length while perplexity increases from 5.6 to 6.3. On larger context length, e.g., 131k, with causal masking, HyperAttention offers 5-fold speedup on a single attention layer.
翻译:我们提出了一种名为HyperAttention的近似注意力机制,以应对大型语言模型(LLMs)中日益复杂的长上下文所带来的计算挑战。近期研究表明,在最坏情况下,注意力矩阵的条目若有界或矩阵的稳定秩较低,则需要二次时间。我们引入两个参数来度量:(1)归一化注意力矩阵中的最大列范数,以及(2)在检测并移除大条目后,未归一化注意力矩阵的行范数之比。我们利用这些细粒度参数来捕捉问题的难度。尽管此前有下界限制,但只要上述参数较小,我们仍能在矩阵条目无界或稳定秩较大的情况下实现线性时间采样算法。HyperAttention采用模块化设计,能够轻松集成其他快速底层实现,尤其是FlashAttention。实验上,通过使用局部敏感哈希(LSH)识别大条目,HyperAttention的性能优于现有方法,相比FlashAttention等最先进解决方案,其速度提升显著。我们在多种不同长上下文数据集上验证了HyperAttention的实证性能。例如,在32k上下文长度下,HyperAttention使ChatGLM2的推理时间加快50%,而困惑度仅从5.6增至6.3。在更长上下文(如131k)并采用因果掩码时,HyperAttention在单层注意力上实现了5倍加速。