Bringing PDEs to JAX with forward and reverse modes automatic differentiation

Partial differential equations (PDEs) are used to describe a variety of
physical phenomena. Often these equations do not have analytical solutions and
numerical approximations are used instead. One of the common methods to solve
PDEs is the finite element method. Computing derivative information of the
solution with respect to the input parameters is important in many tasks in
scientific computing. We extend the JAX automatic differentiation library with an
interface to the Firedrake finite element library. The high-level symbolic
representation of PDEs makes it possible to avoid differentiating through the
possibly many low-level iterations of the underlying nonlinear solvers. Instead,
derivatives of Firedrake solvers are computed using tangent-linear and adjoint equations.
This enables the efficient composition of finite element solvers with arbitrary
differentiable programs. The code is available at
github.com/IvanYashchuk/jax-firedrake.