We present DrJAX, a JAX-based library designed to support large-scale distributed and parallel machine learning algorithms that use MapReduce-style operations. DrJAX leverages JAX's sharding mechanisms to enable native targeting of TPUs and state-of-the-art JAX runtimes, including Pathways. DrJAX embeds building blocks for MapReduce computations as primitives in JAX. This enables three key benefits. First, DrJAX computations can be translated directly to XLA HLO, enabling flexible integration with a wide array of ML training platforms. Second, DrJAX computations are fully differentiable. Finally, DrJAX computations can be interpreted out to existing batch-processing compute systems, including traditional MapReduce systems like Apache Beam and cross-device compute systems like those powering federated learning applications. We show that DrJAX provides an easily programmable, performant, and scalable framework for parallelized algorithm development. DrJAX is available at \url{https://github.com/google-research/google-research/tree/master/drjax}.
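The MapReduce-style pattern the abstract refers to can be sketched in plain JAX (this illustration does not use the DrJAX API itself; the toy per-client update and all names below are hypothetical, chosen only to show the map-then-reduce structure and its differentiability):

```python
# A minimal MapReduce-style round in plain JAX: map a local update
# across simulated clients with vmap, then reduce by averaging.
import jax
import jax.numpy as jnp

def local_step(params, x):
    # Hypothetical per-client work: one gradient step on a toy
    # quadratic loss over that client's data.
    grad = jax.grad(lambda p: jnp.sum((p * x - 1.0) ** 2))(params)
    return params - 0.1 * grad

params = jnp.ones(2)
# Three simulated clients, each holding a small batch of data.
client_data = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Map: run local_step independently for each client.
client_params = jax.vmap(local_step, in_axes=(None, 0))(params, client_data)

# Reduce: average the per-client results, as in federated averaging.
new_params = jnp.mean(client_params, axis=0)
```

Because both the map and the reduce are ordinary JAX transformations, the whole round remains compilable and differentiable end to end, which mirrors the differentiability property the abstract claims for DrJAX's primitives.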