We show how the basic Combinatory Homomorphic Automatic Differentiation (CHAD) algorithm can be optimised, using well-known methods, to yield a simple and generally applicable reverse-mode automatic differentiation (AD) technique with the computational complexity expected of a reverse-mode AD algorithm. Specifically, we show that the standard optimisations of sparse vectors and state-passing style code (together with defunctionalisation/closure conversion for higher-order languages) give a purely functional algorithm that is most of the way to the correct complexity, with (functional) mutable updates removing the remaining logarithmic factors. We provide an Agda formalisation of our complexity proof. Finally, we discuss how the techniques apply to differentiating parallel functional programs. The key observations are (1) that all required mutability is (commutative, associative) accumulation, which lets us preserve task-parallelism, and (2) that we can write down data-parallel derivatives for most data-parallel array primitives.