Neural ordinary differential equations (neural ODEs) have emerged as a novel network architecture that bridges dynamical systems and deep learning. However, the gradient obtained with the continuous adjoint method in the vanilla neural ODE is not reverse-accurate, that is, it does not match the exact gradient of the discretized forward computation. Other approaches suffer either from excessive memory requirements due to deep computational graphs or from a limited choice of time integration schemes, hampering their application to large-scale complex dynamical systems. To achieve accurate gradients without compromising memory efficiency and flexibility, we present a new neural ODE framework, PNODE, based on high-level discrete adjoint algorithmic differentiation. By leveraging discrete adjoint time integrators and advanced checkpointing strategies tailored for these integrators, PNODE balances memory and computational costs while computing the gradients consistently and accurately. We provide an open-source implementation based on PyTorch and PETSc, one of the most commonly used portable, scalable scientific computing libraries. We demonstrate the performance through extensive numerical experiments on image classification and continuous normalizing flow problems. We show that PNODE achieves the highest memory efficiency when compared with other reverse-accurate methods. On the image classification problems, PNODE is up to two times faster than the vanilla neural ODE and up to 2.3 times faster than the best existing reverse-accurate method. We also show that PNODE enables the use of implicit time integration methods, which are needed for stiff dynamical systems.
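To make "reverse-accurate" and the memory/compute trade-off of checkpointing concrete, here is a minimal sketch. It is not PNODE's API or its checkpointing strategy. It assumes a hypothetical toy scalar ODE du/dt = theta * sin(u), forward Euler as the integrator, the loss L = u_N^2 / 2, and simple uniform checkpointing every `every` steps, all chosen for illustration only. Because the adjoint is derived from the discrete Euler steps themselves, the gradient matches a finite-difference estimate of the discrete loss up to round-off. Storing fewer checkpoints changes only how many states are recomputed in the reverse sweep, not the gradient.

```python
import math

def f(u, theta):
    """Right-hand side of the toy ODE du/dt = theta * sin(u) (illustrative)."""
    return theta * math.sin(u)

def df_du(u, theta):
    return theta * math.cos(u)

def df_dtheta(u, theta):
    return math.sin(u)

def loss(u0, theta, h, n_steps):
    """Discrete loss L = 0.5 * u_N**2, with u_N computed by forward Euler."""
    u = u0
    for _ in range(n_steps):
        u = u + h * f(u, theta)
    return 0.5 * u ** 2

def discrete_adjoint_grad(u0, theta, h, n_steps, every=1):
    """dL/dtheta via the discrete adjoint of forward Euler.

    Only every `every`-th state is stored in the forward sweep; states inside
    each segment are recomputed from its checkpoint during the reverse sweep.
    Returns (gradient, number of stored checkpoints).
    """
    # Forward sweep: keep only checkpointed states.
    checkpoints = {0: u0}
    u = u0
    for n in range(n_steps):
        u = u + h * f(u, theta)
        if (n + 1) % every == 0:
            checkpoints[n + 1] = u

    lam = u    # adjoint seed: dL/du_N = u_N for L = 0.5 * u_N**2
    grad = 0.0
    seg_end = n_steps
    while seg_end > 0:
        # Largest checkpoint index strictly below seg_end.
        seg_start = ((seg_end - 1) // every) * every
        # Recompute u_{seg_start}, ..., u_{seg_end - 1} from the checkpoint.
        states = [checkpoints[seg_start]]
        for _ in range(seg_start, seg_end - 1):
            states.append(states[-1] + h * f(states[-1], theta))
        # Reverse through the Euler steps u_{n+1} = u_n + h * f(u_n, theta).
        for u_n in reversed(states):
            grad += lam * h * df_dtheta(u_n, theta)   # uses lambda_{n+1}
            lam = lam * (1.0 + h * df_du(u_n, theta))  # lambda_n
        seg_end = seg_start
    return grad, len(checkpoints)

if __name__ == "__main__":
    u0, theta, h, n = 1.0, 0.7, 0.01, 100
    g_full, m_full = discrete_adjoint_grad(u0, theta, h, n, every=1)
    g_ckpt, m_ckpt = discrete_adjoint_grad(u0, theta, h, n, every=10)
    eps = 1e-6
    g_fd = (loss(u0, theta + eps, h, n) - loss(u0, theta - eps, h, n)) / (2 * eps)
    print(g_full, g_ckpt, g_fd)
    print("stored states:", m_full, "vs", m_ckpt)
```

With `every=1` all 101 states are stored; with `every=10` only 11 are stored, at the cost of recomputing each segment once during the reverse sweep. The two adjoint gradients agree with each other and with the finite-difference estimate of the discrete loss.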