Backpropagation's main limitation is its need to store intermediate activations (residuals) during the forward pass, which restricts the depth of trainable networks. This raises a fundamental question: can we avoid storing these activations? We address this by revisiting the structure of gradient computation. Backpropagation computes gradients through a sequence of vector-Jacobian products, an operation that is generally irreversible. The lost information lies in the cokernel of each layer's Jacobian. We define submersive networks -- networks whose layer Jacobians have trivial cokernels -- in which gradients can be reconstructed exactly in a forward sweep without storing activations. For non-submersive layers, we introduce fragmental gradient checkpointing, which records only the minimal subset of residuals necessary to restore the cotangents erased by the Jacobian. Central to our approach is a novel operator, the vector-inverse-Jacobian product (vijp), which inverts gradient flow outside the cokernel. Our mixed-mode algorithm first computes input gradients with a memory-efficient reverse pass, then reconstructs parameter gradients in a forward sweep using the vijp, eliminating the need to store activations. We implement this method in Moonwalk and show that it matches backpropagation's runtime while training networks more than twice as deep under the same memory budget.
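The abstract describes the vijp and the two-phase mixed-mode algorithm only at a high level. The NumPy sketch below illustrates the idea on a toy stack of tanh layers. Each weight matrix is wide (at least as many inputs as outputs), so every layer Jacobian has full row rank and therefore a trivial cokernel, which makes the toy network submersive. Some of the choices here are assumptions and are not taken from the paper:

- The names `vijp`, `layer`, `input_gradient`, and `mixed_mode_grads` are hypothetical and are not Moonwalk's API.
- The vijp is realized as a least-squares solve, which is one possible implementation and not necessarily the authors'.
- The abstract does not say how its memory-efficient reverse pass works, so phase 1 substitutes forward-mode accumulation of the input-output Jacobian. This also avoids storing activations, but it keeps one extra matrix.

```python
import numpy as np

def vijp(J, g_in):
    """Illustrative vector-inverse-Jacobian product (hypothetical helper).

    Given g_in = J.T @ g_out (a vector-Jacobian product), recover g_out by solving
    J.T @ g_out = g_in in the least-squares sense. Recovery is exact when J has full
    row rank (trivial cokernel). Otherwise the component of g_out in the null space
    of J.T is lost, and lstsq returns the minimum-norm solution.
    """
    g_out, *_ = np.linalg.lstsq(J.T, g_in, rcond=None)
    return g_out

def layer(W, x):
    """One tanh layer y = tanh(W x); returns y and its Jacobian dy/dx = diag(1 - y^2) W."""
    y = np.tanh(W @ x)
    return y, (1.0 - y**2)[:, None] * W

def input_gradient(weights, x0):
    """Phase 1 stand-in: dL/dx0 for the toy loss L = 0.5 * ||x_L||^2.

    Assumption: uses forward-mode accumulation of dx_L/dx0, keeping only the current
    activation and one Jacobian matrix. This is not the paper's reverse pass.
    """
    x, jac = x0, np.eye(x0.size)
    for W in weights:
        x, J = layer(W, x)
        jac = J @ jac          # chain rule: d x_{l+1}/d x0 = J_l @ d x_l/d x0
    return jac.T @ x           # dL/dx_L = x_L, pulled back to the input

def mixed_mode_grads(weights, x0):
    """Phase 2: a forward sweep that recovers parameter gradients via vijp.

    Only the current activation and cotangent are kept; earlier ones are dropped.
    """
    g = input_gradient(weights, x0)              # cotangent of the current layer input
    x = x0
    grads = []
    for W in weights:
        y, J = layer(W, x)
        g_y = vijp(J, g)                         # invert g = J.T @ g_y
        grads.append(np.outer(g_y * (1.0 - y**2), x))  # dL/dW = (g_y * tanh'(Wx)) x^T
        x, g = y, g_y                            # move forward; previous x is discarded
    return grads

# Usage: widths never increase, so each random W has full row rank almost surely.
rng = np.random.default_rng(0)
widths = [6, 5, 4, 3]
weights = [rng.standard_normal((m, n)) for n, m in zip(widths[:-1], widths[1:])]
x0 = rng.standard_normal(widths[0])
grads = mixed_mode_grads(weights, x0)
print([g.shape for g in grads])  # [(5, 6), (4, 5), (3, 4)]
```

For a tall Jacobian (more outputs than inputs), the same `vijp` call still returns a cotangent whose vector-Jacobian product reproduces `g_in`. However, the component of the original cotangent that lies in the null space of `J.T` (the cokernel) is gone. That erased component is what fragmental gradient checkpointing is meant to restore.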