We present an approximate attention mechanism named HyperAttention to address the computational challenges posed by the growing complexity of long contexts used in Large Language Models (LLMs). Recent work suggests that in the worst-case scenario, quadratic time is necessary unless the entries of the attention matrix are bounded or the matrix has low stable rank. We introduce two parameters which measure: (1) the max column norm in the normalized attention matrix, and (2) the ratio of row norms in the unnormalized attention matrix after detecting and removing large entries. We use these fine-grained parameters to capture the hardness of the problem. Despite previous lower bounds, we are able to achieve a linear time sampling algorithm even when the matrix has unbounded entries or a large stable rank, provided the above parameters are small. HyperAttention features a modular design that easily accommodates integration of other fast low-level implementations, particularly FlashAttention. Empirically, employing Locality Sensitive Hashing (LSH) to identify large entries, HyperAttention outperforms existing methods, giving significant speed improvements compared to state-of-the-art solutions like FlashAttention. We validate the empirical performance of HyperAttention on a variety of different long-context length datasets. For example, HyperAttention makes the inference time of ChatGLM2 50\% faster on 32k context length while perplexity increases from 5.6 to 6.3. On larger context length, e.g., 131k, with causal masking, HyperAttention offers 5-fold speedup on a single attention layer.
翻译:我们提出了一种名为HyperAttention的近似注意力机制,旨在应对大语言模型(LLMs)中日益复杂的长上下文所带来的计算挑战。近期研究表明,在最坏情况下,除非注意力矩阵的元素有界或矩阵具有低稳定秩,否则二次时间复杂度的计算是必要的。我们引入了两个参数,分别衡量:(1) 归一化注意力矩阵中的最大列范数,(2) 在检测并移除大条目后未归一化注意力矩阵的行范数比率。我们利用这些细粒度参数来捕捉问题的困难程度。尽管此前存在下界约束,但若上述参数较小,我们仍能在矩阵元素无界或稳定秩较高的情况下实现线性时间采样算法。HyperAttention采用模块化设计,可轻松集成其他快速底层实现(特别是FlashAttention)。实验表明,通过使用局部敏感哈希(LSH)识别大条目,HyperAttention的性能优于现有方法,相较于FlashAttention等最先进解决方案,显著提升了速度。我们在一系列不同长上下文长度的数据集上验证了HyperAttention的实证性能。例如,在32k上下文长度下,HyperAttention使ChatGLM2的推理时间缩短50%,同时困惑度从5.6升至6.3。在更长的上下文长度(如131k)中,结合因果掩码,HyperAttention在单个注意力层上实现了5倍加速。