Attention is a fundamental building block of large language models (LLMs), so there have been many efforts to implement it efficiently. For example, FlashAttention leverages tiling and kernel fusion to optimize attention. A number of attention variants have since been introduced to enhance model quality or efficiency. Supporting them efficiently remains difficult because they usually require specialized kernels or hand-tuned implementations. FlexAttention addressed part of this gap by using static programming templates to support FlashAttention-like kernels for a subset of attention variants. In this paper, we introduce Flashlight, a compiler-native framework within the PyTorch ecosystem that automatically generates fused, FlashAttention-style kernels for arbitrary attention-based programs, without relying on static templates or predefined kernel specializations. Flashlight leverages PyTorch's compilation workflow to fuse and tile attention computations transparently, enabling efficient execution for diverse attention patterns. It supports every variant expressible in the FlexAttention model and also handles more general, data-dependent attention formulations that are beyond the capabilities of FlexAttention. Our results show that Flashlight produces kernels whose performance is competitive with or superior to FlexAttention, while offering the flexibility of native PyTorch code, enabling developers to rapidly explore new attention models without sacrificing performance.
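To make the tiling idea mentioned above concrete, the following is a minimal sketch in plain Python of attention for a single query, computed one key/value tile at a time with a running (online) softmax, alongside a naive reference. It is an illustration of the general FlashAttention-style technique only, not Flashlight's generated code or API. The function names, the `score_mod` hook, and the `distance_bias` example are hypothetical, and the sketch assumes all modified scores are finite (no `-inf` masking).

```python
import math


def naive_attention(q, keys, values, score_mod=None):
    """Reference: materialize every score, apply softmax, then weight the values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = []
    for j, k in enumerate(keys):
        s = scale * sum(a * b for a, b in zip(q, k))
        if score_mod is not None:
            s = score_mod(s, j)  # hypothetical hook for an attention variant
        scores.append(s)
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]


def tiled_attention(q, keys, values, tile=2, score_mod=None):
    """Visit keys/values tile by tile, keeping a running max, sum, and output."""
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * dim
    for start in range(0, len(keys), tile):
        stop = min(start + tile, len(keys))
        tile_scores = []
        for j in range(start, stop):
            s = scale * sum(a * b for a, b in zip(q, keys[j]))
            if score_mod is not None:
                s = score_mod(s, j)
            tile_scores.append(s)
        new_max = max(running_max, max(tile_scores))
        # Rescale what was accumulated under the old max (0.0 on the first tile).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for offset, s in enumerate(tile_scores):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * x for a, x in zip(acc, values[start + offset])]
        running_max = new_max
    return [a / running_sum for a in acc]


def distance_bias(score, key_index):
    """Hypothetical variant: penalize keys by their position index."""
    return score - 0.1 * key_index
```

Both functions should agree up to floating-point rounding for any tile size. The tiled version never holds more than one tile of scores at a time, which is the property that lets a fused kernel avoid materializing the full score matrix.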