We introduce JPC, a JAX library for training neural networks with predictive coding (PC). JPC provides a simple, fast, and flexible interface for training a variety of PC networks (PCNs), including discriminative, generative, and hybrid models. Unlike existing libraries, JPC leverages ordinary differential equation solvers to integrate the gradient-flow inference dynamics of PCNs. We find that a second-order solver achieves significantly faster runtimes than standard Euler integration, with comparable performance across a range of tasks and network depths. JPC also provides theoretical tools for studying PCNs. We hope that JPC will facilitate future research on PC. The code is available at https://github.com/thebuckleylab/jpc.
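To illustrate the core idea of integrating gradient-flow inference dynamics with different ODE solvers, here is a minimal, self-contained sketch for a toy two-layer linear PC model. This is not JPC's API: the weights, energy, and solver loops below are hypothetical constructions for illustration, comparing plain Euler steps against a second-order Heun (predictor-corrector) scheme on the same gradient flow.

```python
# Hedged sketch: gradient-flow inference for a toy 2-layer linear PC model.
# NOT JPC's actual API -- all names and shapes here are illustrative.
import numpy as np

rng = np.random.default_rng(0)
W1 = rng.normal(size=(4, 3)) * 0.5   # hypothetical input -> latent weights
W2 = rng.normal(size=(2, 4)) * 0.5   # hypothetical latent -> output weights
x = rng.normal(size=3)               # clamped input
y = rng.normal(size=2)               # clamped target

def grad_E(z):
    # Gradient of the PC energy E(z) = 0.5*||z - W1 x||^2 + 0.5*||y - W2 z||^2
    # with respect to the latent activity z.
    return (z - W1 @ x) - W2.T @ (y - W2 @ z)

def euler(z, dt, steps):
    # First-order explicit Euler integration of dz/dt = -grad_E(z).
    for _ in range(steps):
        z = z - dt * grad_E(z)
    return z

def heun(z, dt, steps):
    # Second-order Heun method: predictor step, then average of slopes.
    for _ in range(steps):
        k1 = -grad_E(z)
        k2 = -grad_E(z + dt * k1)
        z = z + 0.5 * dt * (k1 + k2)
    return z

# Closed-form equilibrium of the gradient flow: (I + W2^T W2) z* = W1 x + W2^T y
z_star = np.linalg.solve(np.eye(4) + W2.T @ W2, W1 @ x + W2.T @ y)

z0 = np.zeros(4)
err_euler = np.linalg.norm(euler(z0, 0.1, 50) - z_star)
err_heun = np.linalg.norm(heun(z0, 0.1, 50) - z_star)
```

Both schemes drive the latent activity to the same energy minimum; the abstract's point is that higher-order solvers can reach comparable solutions with fewer, larger steps, which in practice translates into faster runtimes.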