JAX is widely used in machine learning and scientific computing, the latter of which often relies on existing high-performance code that we would ideally like to incorporate into JAX. Reimplementing the existing code in JAX is often impractical and the existing interface in JAX for binding custom code either limits the user to a single Jacobian product or requires deep knowledge of JAX and its C++ backend for general Jacobian products. With JAXbind we drastically reduce the effort required to bind custom functions implemented in other programming languages with full support for Jacobian-vector products and vector-Jacobian products to JAX. Specifically, JAXbind provides an easy-to-use Python interface for defining custom, so-called JAX primitives. Via JAXbind, any function callable from Python can be exposed as a JAX primitive. JAXbind allows a user to interface the JAX function transformation engine with custom derivatives and batching rules, enabling all JAX transformations for the custom primitive.
翻译:JAX在机器学习和科学计算领域得到广泛应用,而科学计算通常依赖于现有的高性能代码,我们期望能将其整合到JAX中。用JAX重新实现现有代码往往不切实际,且JAX现有的自定义代码绑定接口要么将用户限制在单一雅可比积运算,要么要求用户对JAX及其C++后端有深入理解才能实现通用雅可比积运算。通过JAXbind,我们大幅降低了将其他编程语言实现的定制函数绑定至JAX的工作量,并全面支持雅可比-向量积和向量-雅可比积运算。具体而言,JAXbind提供了易于使用的Python接口,用于定义自定义的所谓JAX基元。通过JAXbind,任何可从Python调用的函数都能作为JAX基元公开。JAXbind允许用户通过自定义导数与批处理规则连接JAX函数转换引擎,从而为定制基元启用所有JAX变换功能。