Current hierarchical attention methods, such as NSA and InfLLMv2, select the top-k relevant key-value (KV) blocks based on coarse attention scores and subsequently apply fine-grained softmax attention on the selected tokens. However, the top-k operation assumes the number of relevant tokens for any query is fixed and it precludes the gradient flow between the sparse and dense stages. In this work, we propose DashAttention (Differentiable and Adaptive Sparse Hierarchical Attention), which leverages the adaptively sparse $α$-entmax transformation to select a variable number of blocks according to the current query in the first stage. This in turn provides a prior for the second-stage softmax attention, keeping the entire hierarchy fully differentiable. Contrary to other hierarchical attention methods, we show that DashAttention is non-dispersive, translating to better long-context modeling ability. Experiments with large language models (LLMs) show that DashAttention achieves comparable accuracy as full attention with 75% sparsity and a better Pareto frontier than NSA and InfLLMv2, especially in high-sparsity regimes. We also provide an efficient, GPU-aware implementation of DashAttention in Triton, which achieves a speedup of up to over FlashAttention-3 at inference time. Overall, DashAttention offers a cost-effective strategy to model long contexts.
翻译:当前的层次注意力方法(如NSA和InfLLMv2)基于粗粒度注意力得分选择top-k相关键值块,随后对所选令牌应用细粒度softmax注意力。然而,top-k操作假设任意查询的相关令牌数量固定,且阻碍了稀疏阶段与密集阶段之间的梯度流动。为此,我们提出DashAttention(可微分自适应稀疏层次注意力),该方法在第一阶段利用自适应稀疏的α-entmax变换根据当前查询选择数量可变的块,从而为第二阶段的softmax注意力提供先验信息,并保持整个层次结构的完全可微性。与其他层次注意力方法不同,我们证明DashAttention具有非分散性,这转化为更优的长上下文建模能力。在大语言模型上的实验表明,DashAttention在75%稀疏度下实现了与全注意力相当的精度,并在帕累托前沿上优于NSA和InfLLMv2,尤其是在高稀疏度场景中。我们还提供了基于Triton的高效GPU感知实现,该实现推理速度相比FlashAttention-3最高提升数倍。综上,DashAttention为长上下文建模提供了一种经济高效的策略。