Sparse attention has been proposed as a way to alleviate the quadratic cost of transformers, a central bottleneck in long-context training. A promising line of work is $\alpha$-entmax attention, a differentiable sparse alternative to softmax that enables input-dependent sparsity but has lagged behind softmax in efficiency due to the overhead of computing the normalizer $\tau$. In this paper, we introduce AdaSplash-2, which addresses this limitation through a novel histogram-based initialization that typically reduces the number of iterations needed to compute $\tau$ to 1--2. The key idea is to compute a coarse histogram of attention scores on the fly and store it in on-chip SRAM, yielding a more accurate initialization that enables fast forward and backward computation. Combined with a sparsity-aware GPU implementation that skips zero blocks with low overhead, AdaSplash-2 matches or improves per-step training time relative to FlashAttention-2 when block sparsity is moderate to high (e.g., $>$60\%), which is common at long context lengths. On downstream tasks, models trained with our efficient $\alpha$-entmax attention match softmax baselines at short context lengths and achieve substantial gains in long-context settings.
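To make the role of $\tau$ and of a histogram-based initialization concrete, here is a minimal pure-Python sketch for a single row of scores. It is not the AdaSplash-2 GPU kernel, and its binning and update rules are assumptions that may differ from the paper's. The sketch assumes the standard $\alpha$-entmax form $p_i = [(\alpha-1)z_i - \tau]_+^{1/(\alpha-1)}$, with $\tau$ chosen so that $\sum_i p_i = 1$. It also assumes the usual bisection bracket $[\max_i x_i - 1,\ \max_i x_i - n^{1-\alpha}]$ with $x = (\alpha-1)z$. The helper `histogram_bracket` is a hypothetical illustration: it bins the scores once, then uses per-bin lower and upper bounds on the total mass to shrink the bracket before any root-finding step. All function names are illustrative.

```python
def entmax_setup(scores, alpha):
    """Scale scores to x = (alpha - 1) * z and return the standard bracket for tau."""
    x = [(alpha - 1.0) * s for s in scores]
    m = max(x)
    n = len(x)
    # At tau = max(x) - 1 the mass is >= 1; at tau = max(x) - n**(1 - alpha) it is <= 1.
    return x, m - 1.0, m - n ** (1.0 - alpha)


def mass(x, tau, alpha):
    """Total probability sum_i [x_i - tau]_+^(1/(alpha-1)); decreasing in tau."""
    p = 1.0 / (alpha - 1.0)
    return sum((xi - tau) ** p for xi in x if xi > tau)


def histogram_bracket(x, alpha, lo, n_bins=8):
    """Hypothetical histogram-based narrowing of the bracket [lo, max(x)] for tau.

    For each bin edge e_k tried as tau, a score in bin j >= k contributes between
    (e_j - e_k)^p and (e_{j+1} - e_k)^p, and scores in lower bins contribute 0.
    Only the bin counts are needed to evaluate these bounds.
    """
    p = 1.0 / (alpha - 1.0)
    m = max(x)
    width = (m - lo) / n_bins
    counts = [0] * n_bins
    for xi in x:
        if xi >= lo:  # scores below lo never receive probability
            counts[min(int((xi - lo) / width), n_bins - 1)] += 1
    edges = [lo + k * width for k in range(n_bins + 1)]
    new_lo, new_hi = lo, m
    for k in range(n_bins):
        tau = edges[k]
        s_lo = sum(counts[j] * (edges[j] - tau) ** p for j in range(k, n_bins))
        s_hi = sum(counts[j] * (edges[j + 1] - tau) ** p for j in range(k, n_bins))
        if s_lo >= 1.0:  # mass at e_k is at least 1, so the true tau >= e_k
            new_lo = max(new_lo, tau)
        if s_hi <= 1.0:  # mass at e_k is at most 1, so the true tau <= e_k
            new_hi = min(new_hi, tau)
    return new_lo, new_hi


def bisect_tau(x, alpha, lo, hi, tol=1e-10, max_iter=100):
    """Plain bisection on the monotone mass function (chosen for clarity, not speed)."""
    for _ in range(max_iter):
        if hi - lo <= tol:
            break
        mid = 0.5 * (lo + hi)
        if mass(x, mid, alpha) >= 1.0:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def entmax(scores, alpha=1.5, n_bins=8):
    """alpha-entmax probabilities for one row of attention scores (alpha > 1)."""
    x, lo, hi = entmax_setup(scores, alpha)
    h_lo, h_hi = histogram_bracket(x, alpha, lo, n_bins)
    lo, hi = max(lo, h_lo), min(hi, h_hi)  # intersect both valid brackets
    tau = bisect_tau(x, alpha, lo, hi)
    p = 1.0 / (alpha - 1.0)
    probs = [max(xi - tau, 0.0) ** p for xi in x]
    return probs, tau, (lo, hi)


scores = [2.0, 1.0, 0.5, -1.0, -3.0]
probs, tau, bracket = entmax(scores, alpha=1.5)
print([round(q, 4) for q in probs])  # the last two scores get exactly 0.0
print(bracket)  # (0.0, 0.25)
```

On this example with $\alpha = 1.5$ and 8 bins, the histogram shrinks the starting bracket from roughly $[0, 0.553]$ to $[0, 0.25]$. The two lowest scores receive exactly zero probability, which is the sparsity that block skipping exploits. Bisection is used here only for readability. The 1--2 iteration figure quoted above refers to the authors' own solver and GPU kernel, not to this sketch.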