This paper introduces JaxPruner, an open-source JAX-based pruning and sparse training library for machine learning research. JaxPruner aims to accelerate research on sparse neural networks by providing concise implementations of popular pruning and sparse training algorithms with minimal memory and latency overhead. Algorithms implemented in JaxPruner use a common API and work seamlessly with the popular optimization library Optax, which, in turn, enables easy integration with existing JAX based libraries. We demonstrate this ease of integration by providing examples in four different codebases: Scenic, t5x, Dopamine and FedJAX and provide baseline experiments on popular benchmarks.
翻译:本文介绍了JaxPruner,一个基于JAX的开源剪枝与稀疏训练库,专为机器学习研究设计。JaxPruner通过提供流行剪枝和稀疏训练算法的简洁实现,并以最小的内存和延迟开销为代价,旨在加速稀疏神经网络的研究。该库中实现的算法采用通用API,并与流行的优化库Optax无缝协作,从而便于与现有基于JAX的库进行集成。我们通过在四个不同的代码库中(Scenic、t5x、Dopamine和FedJAX)提供示例,展示了这种集成的便捷性,并在主流基准测试上提供了基线实验。