We present JaxPP, a system for efficiently scaling the training of large deep learning models with flexible pipeline parallelism. We introduce a seamless programming model that allows users to implement custom pipeline schedules for gradient accumulation. JaxPP automatically distributes tasks, corresponding to pipeline stages, over a cluster of nodes and automatically infers the communication among them. We implement an MPMD runtime for asynchronous execution of SPMD tasks. The pipeline parallelism implementation of JaxPP improves hardware utilization by up to $1.11\times$ over the best-performing SPMD configuration.