BlackJAX is a library implementing sampling and variational inference algorithms commonly used in Bayesian computation. It is designed for ease of use, speed, and modularity by taking a functional approach to the algorithms' implementation. BlackJAX is written in Python, using JAX to compile and run NumpPy-like samplers and variational methods on CPUs, GPUs, and TPUs. The library integrates well with probabilistic programming languages by working directly with the (un-normalized) target log density function. BlackJAX is intended as a collection of low-level, composable implementations of basic statistical 'atoms' that can be combined to perform well-defined Bayesian inference, but also provides high-level routines for ease of use. It is designed for users who need cutting-edge methods, researchers who want to create complex sampling methods, and people who want to learn how these work.
翻译:BlackJAX是一个实现了贝叶斯计算中常用采样与变分推断算法的库。该库采用函数式方法实现算法,旨在确保易用性、高性能和模块化。BlackJAX使用Python编写,通过JAX编译并运行类似NumPy的采样器和变分方法,支持CPU、GPU和TPU。该库直接操作(未归一化的)目标对数密度函数,因此能很好地与概率编程语言集成。BlackJAX旨在作为底层、可组合的基本统计“原子”的集合,这些原子可组合起来执行定义明确的贝叶斯推断,同时提供高级例程以方便使用。它面向需要前沿方法的用户、希望创建复杂采样方法的研究人员,以及希望了解这些方法工作原理的学习者。