We present an approximate attention mechanism named HyperAttention to address the computational challenges posed by the increasingly long contexts used in Large Language Models (LLMs). Recent work suggests that, in the worst case, quadratic time is necessary unless the entries of the attention matrix are bounded or the matrix has low stable rank. We introduce two fine-grained parameters that capture the hardness of the problem: (1) the maximum column norm in the normalized attention matrix, and (2) the ratio of row norms in the unnormalized attention matrix after detecting and removing large entries. Despite previous lower bounds, we achieve a linear-time sampling algorithm even when the matrix has unbounded entries or a large stable rank, provided these two parameters are small. HyperAttention has a modular design that readily integrates other fast low-level implementations, particularly FlashAttention. Empirically, using Locality Sensitive Hashing (LSH) to identify large entries, HyperAttention outperforms existing methods, giving significant speed improvements over state-of-the-art solutions such as FlashAttention. We validate the empirical performance of HyperAttention on a variety of long-context datasets. For example, HyperAttention makes the inference time of ChatGLM2 50\% faster at 32k context length, while perplexity increases from 5.6 to 6.3. At larger context lengths, e.g., 131k, with causal masking, HyperAttention offers a 5-fold speedup on a single attention layer.
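To make the two ingredients named above concrete (LSH to locate large attention entries, and sampling to account for the remaining ones), the following is a minimal, simplified sketch in plain Python. It is not the authors' implementation. The function names (`simhash`, `approx_attention`, `exact_attention`), the use of random-hyperplane hashing, the per-query uniform sampling, and the parameters `n_bits` and `n_samples` are all illustrative assumptions. The sketch also builds the list of non-bucket keys explicitly for each query, which is itself linear work per query, so it illustrates the estimator rather than the linear-time guarantee.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def exact_attention(Q, K, V):
    """Reference softmax attention; quadratic in the sequence length."""
    out = []
    for q in Q:
        scores = [dot(q, k) for k in K]
        top = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[c] for w, v in zip(weights, V)) / total
                    for c in range(len(V[0]))])
    return out


def simhash(x, planes):
    """Sign pattern of x against random hyperplanes (an angular LSH)."""
    return tuple(dot(x, p) >= 0 for p in planes)


def approx_attention(Q, K, V, n_bits=2, n_samples=4, seed=0):
    """Illustrative approximation (assumed scheme, not the paper's algorithm).

    Keys sharing a hash bucket with the query are treated as candidate
    large entries and handled exactly; the remaining keys are estimated
    from a uniform sample, reweighted by the inverse sampling rate.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")
    rng = random.Random(seed)
    dim = len(Q[0])
    planes = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(n_bits)]

    # Group key indices by hash bucket.
    buckets = {}
    for j, k in enumerate(K):
        buckets.setdefault(simhash(k, planes), []).append(j)

    out = []
    for q in Q:
        heavy = buckets.get(simhash(q, planes), [])
        heavy_set = set(heavy)
        # Built explicitly for clarity; this step alone costs O(n) per query.
        rest = [j for j in range(len(K)) if j not in heavy_set]
        m = min(n_samples, len(rest))
        sampled = rng.sample(rest, m)
        scale = len(rest) / m if m else 0.0

        idx = heavy + sampled
        multipliers = [1.0] * len(heavy) + [scale] * m
        scores = [dot(q, K[j]) for j in idx]
        top = max(scores)
        weights = [c * math.exp(s - top) for c, s in zip(multipliers, scores)]
        total = sum(weights)
        out.append([sum(w * V[j][c] for w, j in zip(weights, idx)) / total
                    for c in range(len(V[0]))])
    return out
```

When `n_samples` is at least the number of keys, every non-bucket key is sampled with weight 1 and the sketch reduces to exact attention, which gives a simple sanity check. Smaller values trade accuracy for fewer score evaluations per query.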