Partial differential equations (PDEs) are used to describe a variety of physical phenomena. Often these equations do not have analytical solutions and numerical approximations are used instead. One of the common methods to solve PDEs is the finite element method. Computing derivative information of the solution with respect to the input parameters is important in many tasks in scientific computing. We extend JAX automatic differentiation library with an interface to Firedrake finite element library. High-level symbolic representation of PDEs allows bypassing differentiating through low-level possibly many iterations of the underlying nonlinear solvers. Differentiating through Firedrake solvers is done using tangent-linear and adjoint equations. This enables the efficient composition of finite element solvers with arbitrary differentiable programs. The code is available at github.com/IvanYashchuk/jax-firedrake.
翻译:偏微分方程(PDE)用于描述多种物理现象。这些方程通常没有解析解,因此需要借助数值近似方法。求解PDE的常用方法之一是有限元法。在科学计算的众多任务中,计算解对输入参数的导数信息至关重要。我们扩展了JAX自动微分库,为其提供了与Firedrake有限元库的接口。PDE的高层符号化表示避免了通过底层非线性求解器可能的多次迭代进行微分的过程。对Firedrake求解器的微分通过切线线性方程和伴随方程实现。这使得有限元求解器能够与任意可微分程序高效组合。相关代码发布在github.com/IvanYashchuk/jax-firedrake。