JAX is widely used in machine learning and scientific computing, the latter of which often relies on existing high-performance code that we would ideally like to incorporate into JAX. Reimplementing the existing code in JAX is often impractical and the existing interface in JAX for binding custom code requires deep knowledge of JAX and its C++ backend. The goal of JAXbind is to drastically reduce the effort required to bind custom functions implemented in other programming languages to JAX. Specifically, JAXbind provides an easy-to-use Python interface for defining custom so-called JAX primitives that support arbitrary JAX transformations.
翻译:JAX广泛应用于机器学习和科学计算,后者通常依赖现有的高性能代码,而我们理想中希望将这些代码整合到JAX中。在JAX中重新实现现有代码往往不切实际,而JAX中用于绑定自定义代码的现有接口需要对JAX及其C++后端有深入理解。JAXbind的目标是大幅减少将其他编程语言实现的自定义函数绑定到JAX所需的工作量。具体而言,JAXbind提供了一个易用的Python接口,用于定义支持任意JAX变换的自定义所谓JAX原语。