We present FAX, a JAX-based library designed to support large-scale distributed and federated computations in both data center and cross-device applications. FAX leverages JAX's sharding mechanisms to enable native targeting of TPUs and state-of-the-art JAX runtimes, including Pathways. FAX embeds building blocks for federated computations as primitives in JAX. This enables three key benefits. First, FAX computations can be translated to XLA HLO. Second, FAX provides a full implementation of federated automatic differentiation, greatly simplifying the expression of federated computations. Last, FAX computations can be interpreted out to existing production cross-device federated compute systems. We show that FAX provides an easily programmable, performant, and scalable framework for federated computations in the data center. FAX is available at https://github.com/google-research/google-research/tree/master/fax .
翻译:我们提出了FAX,一个基于JAX的库,旨在支持数据中心和跨设备应用中大规模分布式与联邦计算。FAX利用JAX的分片机制,原生支持TPU及最先进的JAX运行时(包括Pathways)。FAX将联邦计算的构建模块嵌入为JAX的原语,从而带来三个关键优势:首先,FAX计算可转换为XLA HLO;其次,FAX提供了联邦自动微分的完整实现,极大简化了联邦计算的表达;最后,FAX计算可被解释为现有的生产级跨设备联邦计算系统。实验表明,FAX为数据中心的联邦计算提供了一个易编程、高性能且可扩展的框架。FAX代码开源地址:https://github.com/google-research/google-research/tree/master/fax 。