FlashAttention (Dao, 2023) effectively reduces the peak memory usage of attention from quadratic to linear in sequence length when training transformer-based large language models (LLMs) on a single GPU. In this paper, we introduce DISTFLASHATTN, a distributed memory-efficient attention mechanism optimized for long-context LLM training. We propose three key techniques: token-level workload balancing, overlapping key-value communication, and a rematerialization-aware gradient checkpointing algorithm. We evaluate DISTFLASHATTN on Llama-7B and its variants with sequence lengths from 32K to 512K. Compared to Ring Self-Attention, DISTFLASHATTN supports 8x longer sequences with a 4.45-5.64x speedup; compared to Megatron-LM with FlashAttention, it supports 2-8x longer sequences with a 1.24-2.01x speedup. It also achieves 1.67x and 1.26-1.88x speedups over the recent Ring Attention and DeepSpeed-Ulysses, respectively. Code is available at https://github.com/RulinShao/LightSeq.
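The imbalance that token-level workload balancing targets can be illustrated with a toy model. Assume causal attention with the sequence split into P equal chunks, one per worker. Worker p must then compute attention between its query chunk and key chunks 0 through p, so the last worker does P units of work while the first does only 1. The sketch below counts (query chunk, key chunk) pairs and greedily hands surplus pairs from overloaded workers to underloaded ones. The function names and the greedy reassignment are hypothetical illustrations, not DISTFLASHATTN's actual schedule. The sketch also ignores the cost of shipping query/key chunks and partial outputs between workers.

```python
from typing import Dict, List, Tuple

Pair = Tuple[int, int]  # (query_chunk, key_chunk) unit of attention work


def naive_causal_schedule(num_workers: int) -> Dict[int, List[Pair]]:
    """Worker p owns query chunk p and, under a causal mask, attends to key chunks 0..p."""
    return {p: [(p, k) for k in range(p + 1)] for p in range(num_workers)}


def balanced_causal_schedule(num_workers: int) -> Dict[int, List[Pair]]:
    """Hypothetical greedy rebalancing: move surplus pairs to underloaded workers.

    This is an illustrative assumption, not the paper's algorithm; it only
    equalizes pair counts and ignores communication cost.
    """
    naive = naive_causal_schedule(num_workers)
    total = sum(len(pairs) for pairs in naive.values())
    base, extra = divmod(total, num_workers)
    # Give the extra units to the highest ranks, which already hold the most work.
    targets = [base + (1 if p >= num_workers - extra else 0) for p in range(num_workers)]

    schedule: Dict[int, List[Pair]] = {}
    pool: List[Pair] = []  # surplus pairs waiting to be reassigned
    for p in range(num_workers):
        schedule[p] = naive[p][: targets[p]]
        pool.extend(naive[p][targets[p]:])

    # Underloaded workers pull pairs from the pool until they reach their target.
    for p in range(num_workers):
        while len(schedule[p]) < targets[p]:
            schedule[p].append(pool.pop())

    assert not pool, "every surplus pair must be reassigned"
    return schedule


if __name__ == "__main__":
    for name, sched in [("naive", naive_causal_schedule(4)),
                        ("balanced", balanced_causal_schedule(4))]:
        print(name, [len(sched[p]) for p in range(4)])
    # prints:
    # naive [1, 2, 3, 4]
    # balanced [2, 2, 3, 3]
```

With 4 workers, the maximum per-worker load drops from 4 to 3 pairs. In general it drops from P to ceil((P + 1) / 2), roughly halving the critical path, while the set of computed pairs is unchanged.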