In the landscape of high-performance distributed ML systems, JAX has emerged as a framework of choice. However, JAX's modular design philosophy leaves it without a standardized checkpointing solution. In this paper, we introduce Orbax, a modular, JAX-native checkpointing library that abstracts away the complexities of distributed accelerator systems while remaining flexible enough to support user-friendly checkpoint manipulation throughout the ML model lifecycle. We demonstrate performance exceeding that of comparable PyTorch competitors by up to 3.5$\times$ for saving and 2$\times$ for loading. The library is available at https://github.com/google/orbax.