We extend JAX with the capability to automatically differentiate higher-order functions (functionals and operators). By representing functions as a generalization of arrays, we seamlessly use JAX's existing primitive system to implement higher-order functions. We present a set of primitive operators that serve as foundational building blocks for constructing several key types of functionals. For every introduced primitive operator, we derive and implement both linearization and transposition rules, aligning with JAX's internal protocols for forward and reverse mode automatic differentiation. This enhancement allows for functional differentiation in the same syntax traditionally use for functions. The resulting functional gradients are themselves functions ready to be invoked in python. We showcase this tool's efficacy and simplicity through applications where functional derivatives are indispensable. The source code of this work is released at https://github.com/sail-sg/autofd .
翻译:我们将JAX扩展为能够自动微分高阶函数(泛函与算子)。通过将函数表示为数组的推广形式,我们无缝利用JAX现有的原语系统实现高阶函数。本文提出一组作为构建基础的原始算子,用于构造若干关键类型的泛函。针对每个引入的原始算子,我们推导并实现了其线性化与转置规则,使之一致遵循JAX前向与反向模式自动微分的内部协议。这一增强使得用户能够以传统函数的语法进行泛函微分,所得的泛函梯度本身即为可在Python中调用的函数。我们通过泛函微分不可或缺的应用场景,展示了该工具的高效性与简洁性。本工作的源代码发布在https://github.com/sail-sg/autofd。