Attention, as a core layer of the ubiquitous Transformer architecture, is the bottleneck for large language models and long-context applications. While FlashAttention-3 optimized attention for Hopper GPUs through asynchronous execution and warp specialization, it primarily targets the H100 architecture. The AI industry has rapidly transitioned to deploying Blackwell-based systems such as the B200 and GB200, which exhibit fundamentally different performance characteristics due to asymmetric hardware scaling: tensor core throughput doubles while other functional units (shared memory bandwidth, exponential units) scale more slowly or remain unchanged. We develop several techniques to address these shifting bottlenecks on Blackwell GPUs: (1) redesigned pipelines that exploit fully asynchronous MMA operations and larger tile sizes, (2) software-emulated exponential and conditional softmax rescaling that reduces non-matmul operations, and (3) leveraging tensor memory and the 2-CTA MMA mode to reduce shared memory traffic and atomic adds in the backward pass. We demonstrate that our method, FlashAttention-4, achieves up to 1.3$\times$ speedup over cuDNN 9.13 and 2.7$\times$ over Triton on B200 GPUs with BF16, reaching up to 1613 TFLOPs/s (71% utilization). Beyond algorithmic innovations, we implement FlashAttention-4 entirely in CuTe-DSL embedded in Python, achieving 20-30$\times$ faster compile times compared to traditional C++ template-based approaches while maintaining full expressivity.
翻译:注意力机制作为普遍存在的Transformer架构的核心层,是大型语言模型和长上下文应用的性能瓶颈。尽管FlashAttention-3通过异步执行和线程束(warp)专用化针对Hopper GPU优化了注意力计算,但它主要面向H100架构。AI行业已迅速转向部署基于Blackwell架构的系统(如B200和GB200),这些系统由于非对称硬件扩展而展现出根本不同的性能特征:张量核心吞吐量翻倍,而其他功能单元(共享内存带宽、指数运算单元)扩展较慢或保持不变。我们开发了多种技术以应对Blackwell GPU上这些不断变化的瓶颈:(1)利用完全异步的MMA(矩阵乘累加)操作和更大分块尺寸的重新设计流水线;(2)通过软件模拟的指数运算和条件性softmax重缩放,以减少非矩阵乘法运算;(3)利用张量内存和2-CTA MMA模式,以减少反向传播过程中的共享内存流量和原子加法操作。我们证明,在B200 GPU上使用BF16精度时,我们的方法FlashAttention-4相比cuDNN 9.13实现了最高1.3倍的加速,相比Triton实现了最高2.7倍的加速,算力峰值达到1613 TFLOPs/s(利用率为71%)。除了算法创新,我们完全使用嵌入在Python中的CuTe-DSL实现了FlashAttention-4,相比传统的基于C++模板的方法,编译时间加快了20-30倍,同时保持了完整的表达能力。