Context parallelism (CP) is essential for training large-scale, long-context language models because it partitions input sequences across devices to reduce per-device memory. However, existing CP methods suffer from workload imbalance, inefficient kernels, and redundant communication, which stem from static sequence sharding and the exchange of key-value (KV) tensors. We present FlashCP, a load-balanced and communication-efficient framework for CP training. FlashCP introduces a sharding-aware communication mechanism that eliminates redundant KV communication, together with a novel Whole-Doc sharding strategy that maximizes communication savings while keeping workloads balanced. To combine Whole-Doc and Per-Doc sharding efficiently, FlashCP further uses a heuristic algorithm that searches for near-optimal sharding plans. Extensive experiments show that FlashCP achieves up to 1.63x speedup over state-of-the-art CP frameworks across diverse datasets.
翻译:上下文并行(CP)是训练大规模长上下文语言模型的关键技术,通过序列分片降低内存开销。然而现有CP方法因静态序列分片和键值(KV)张量通信,存在负载不均衡、内核效率低下及冗余通信问题。本文提出FlashCP——一种负载均衡且通信高效的CP训练框架。FlashCP引入基于分片感知的通信机制消除冗余KV通信,并创新提出Whole-Doc分片策略,在保持负载均衡的同时最大化通信节省。为高效结合Whole-Doc与Per-Doc分片,FlashCP进一步设计启发式算法搜索近最优分片方案。大量实验表明,FlashCP在不同数据集上相较于最先进的CP框架可实现最高1.63倍加速。
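The abstract does not spell out how Whole-Doc and Per-Doc sharding are combined. The sketch below is only one hypothetical way to picture the trade-off, not FlashCP's algorithm. It assumes that causal attention over a document of length L costs about L(L+1)/2 query-key pairs, that a Whole-Doc document lives entirely on one rank (so it needs no KV exchange), and that a Per-Doc document is split so its cost divides evenly across all ranks. The function names `causal_cost` and `plan_sharding`, the `tolerance` parameter, and the greedy placement rule are all illustrative assumptions.

```python
"""Hypothetical sketch of a Whole-Doc / Per-Doc sharding search.

Assumptions (illustrative only, not taken from FlashCP):
- causal attention cost of a length-L document ~ L * (L + 1) / 2
- Whole-Doc: the document is placed on a single rank (no KV communication)
- Per-Doc: the document is split so its cost is shared equally by all ranks
"""
from typing import Dict, List, Tuple


def causal_cost(length: int) -> int:
    # Number of (query, key) pairs that survive a causal mask.
    return length * (length + 1) // 2


def plan_sharding(
    doc_lens: List[int], num_ranks: int, tolerance: float = 0.1
) -> Tuple[List[int], Dict[int, int], List[float]]:
    """Return (per_doc_ids, whole_doc_owner, per_rank_load).

    Tries to keep as many documents Whole-Doc as possible (to save KV
    communication) while the most-loaded rank stays within `tolerance`
    of the average load.
    """
    # Largest documents first: they are the main source of imbalance.
    order = sorted(range(len(doc_lens)),
                   key=lambda i: causal_cost(doc_lens[i]), reverse=True)
    total = sum(causal_cost(n) for n in doc_lens)
    target = (1 + tolerance) * total / num_ranks

    # k = number of largest documents that fall back to Per-Doc sharding.
    for k in range(len(order) + 1):
        per_doc = order[:k]
        shared = sum(causal_cost(doc_lens[i]) for i in per_doc) / num_ranks
        loads = [shared] * num_ranks
        owner: Dict[int, int] = {}
        # Place the remaining documents whole, each on the least-loaded rank.
        for i in order[k:]:
            r = min(range(num_ranks), key=loads.__getitem__)
            loads[r] += causal_cost(doc_lens[i])
            owner[i] = r
        if max(loads) <= target:
            return per_doc, owner, loads

    # Not reached: with k == len(order) every rank carries exactly the average.
    raise AssertionError("no balanced plan found")


if __name__ == "__main__":
    per_doc, owner, loads = plan_sharding([8192, 1024, 1024, 512, 512, 256], 4)
    print("Per-Doc documents:", per_doc)
    print("Whole-Doc owners:", owner)
```

In this toy batch, keeping every document whole would place the 8192-token document's quadratic cost on a single rank. Splitting only that one document across ranks and keeping the short documents whole restores balance, which is the kind of trade-off the abstract attributes to combining the two strategies.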