We provide an optimized implementation of the forward pass of FlashAttention-2, a popular memory-aware scaled dot-product attention algorithm, as a custom fused CUDA kernel targeting the NVIDIA Hopper architecture and written using the open-source CUTLASS library. In doing so, we explain the challenges and techniques involved in fusing online-softmax with back-to-back GEMM kernels, utilizing the Hopper-specific Tensor Memory Accelerator (TMA) and Warpgroup Matrix-Multiply-Accumulate (WGMMA) instructions, defining and transforming CUTLASS Layouts and Tensors, overlapping copy and GEMM operations, and choosing optimal tile sizes for the Q, K and V attention matrices while balancing register pressure against shared memory utilization. In head-to-head benchmarks on a single H100 PCIe GPU for some common choices of hyperparameters, we observe 20-50% higher FLOPs/s than a version of FlashAttention-2 optimized for the previous-generation NVIDIA Ampere architecture.
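To make the phrase "fusing online-softmax with back-to-back GEMM kernels" concrete, here is a minimal CPU sketch in plain Python for a single query row. It is not the CUTLASS kernel described above and says nothing about TMA, WGMMA, or tile-size selection. The function names, the `tile` parameter, and the variable names (`m`, `l`, `acc`) are illustrative assumptions, not taken from the paper. The sketch only shows that K and V can be processed tile by tile, rescaling a running accumulator, and still match the unfused result.

```python
import math

def attention_reference(q, K, V):
    """Unfused baseline: softmax(q K^T / sqrt(d)) V for one query row."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K]
    m = max(scores)                      # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, V)) / total
            for j in range(len(V[0]))]

def attention_online(q, K, V, tile=2):
    """Tiled variant: one pass over K/V tiles with a running max and running sum."""
    scale = 1.0 / math.sqrt(len(q))
    m = -math.inf                        # running row maximum
    l = 0.0                              # running sum of exponentials
    acc = [0.0] * len(V[0])              # unnormalized output accumulator
    for start in range(0, len(K), tile):
        K_tile = K[start:start + tile]
        V_tile = V[start:start + tile]
        # First product (stands in for the Q K^T GEMM): scores for this tile.
        s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K_tile]
        m_new = max(m, max(s))
        # Rescale factor for everything accumulated under the old maximum;
        # exp(-inf) evaluates to 0.0, which covers the first tile.
        correction = math.exp(m - m_new)
        p = [math.exp(x - m_new) for x in s]
        l = l * correction + sum(p)
        # Second product (stands in for the P V GEMM), fused with the rescale.
        acc = [a * correction + sum(pi * v[j] for pi, v in zip(p, V_tile))
               for j, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]          # normalize once at the end
```

Because the full score row is never materialized, only one tile of scores is live at a time. On a GPU, that is what allows the intermediate values to stay on-chip between the two matrix products.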