Transformers have significantly advanced AI and machine learning through their powerful attention mechanism. However, computing attention on long sequences can become a computational bottleneck. FlashAttention mitigates this by fusing the softmax and matrix operations into a tiled computation pattern that decouples performance from sequence length. Though designed for GPUs, its simplicity also makes it well suited for direct hardware acceleration. To improve hardware implementation, we compute FlashAttention using a mixture of floating-point and fixed-point logarithm-domain representations. Floating-point arithmetic is used to compute the attention scores from the query and key matrices, while logarithmic computation simplifies the fused evaluation of the softmax normalization and the multiplication with the value matrix. This transformation, called H-FA, replaces vector-wide floating-point multiplication and division operations with additions and subtractions implemented efficiently with fixed-point arithmetic in the logarithm domain. Exponential function evaluations are effectively omitted, being fused with the remaining operations, and the final result is converted back to floating-point arithmetic without any additional hardware overhead. Hardware implementation results at 28nm demonstrate that H-FA achieves a 26.5% reduction in area and a 23.4% reduction in power, on average, compared to FlashAttention parallel hardware architectures built solely with floating-point datapaths, without hindering performance.
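The core arithmetic idea can be illustrated numerically: in the log domain, the softmax weight exp(s_i)/Σ_j exp(s_j) times a value v_i collapses into a single exponent of a sum, s_i − logsumexp(s) + log v_i, so all per-element multiplications and the normalizing division become additions and subtractions. The sketch below (plain Python floats, positive values only, function name `log_domain_attention_row` is illustrative and not from the paper) demonstrates the rearrangement; it does not model the paper's fixed-point hardware datapath.

```python
import math

def log_domain_attention_row(scores, values):
    """Sketch of fused softmax * V for one row, computed via the log domain.

    Identity used: softmax(s)_i * v_i = exp(s_i - logsumexp(s) + log(v_i)),
    so the multiply and the normalizing divide become add/subtract.
    Assumes all values are positive (sign handling omitted for clarity).
    """
    m = max(scores)  # running max, as in FlashAttention, for numerical stability
    # log of the softmax normalizer via log-sum-exp
    log_l = m + math.log(sum(math.exp(s - m) for s in scores))
    # accumulate each weighted value using only log-domain additions/subtractions,
    # with a single exponential per term at the very end
    out = 0.0
    for s, v in zip(scores, values):
        out += math.exp(s - log_l + math.log(v))
    return out
```

A direct evaluation of softmax followed by a weighted sum gives the same result; the log-domain form merely trades multiplications and a division for additions and subtractions, which is what makes a fixed-point hardware datapath attractive.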