Recent advances to algorithms for training spiking neural networks (SNNs) often leverage their unique dynamics. While backpropagation through time (BPTT) with surrogate gradients dominate the field, a rich landscape of alternatives can situate algorithms across various points in the performance, bio-plausibility, and complexity landscape. Evaluating and comparing algorithms is currently a cumbersome and error-prone process, requiring them to be repeatedly re-implemented. We introduce Slax, a JAX-based library designed to accelerate SNN algorithm design, compatible with the broader JAX and Flax ecosystem. Slax provides optimized implementations of diverse training algorithms, allowing direct performance comparison. Its toolkit includes methods to visualize and debug algorithms through loss landscapes, gradient similarities, and other metrics of model behavior during training.
翻译:近期针对脉冲神经网络(SNN)训练算法的进展常利用其独特动力学特性。尽管基于时间反向传播(BPTT)与替代梯度的算法主导该领域,但丰富的替代方案可分布于性能、生物合理性及复杂度等不同维度的算法空间中。当前评估与比较这些算法是一个繁琐且易出错的过程,需要反复重新实现。我们提出Slax——基于JAX的库,旨在加速SNN算法设计,并与JAX和Flax生态系统兼容。Slax提供多种训练算法的优化实现,支持直接性能比较。其工具包包含通过损失景观、梯度相似性及其他训练过程中模型行为指标来可视化和调试算法的方法。