Compared to the wide array of advanced Monte Carlo methods supported by modern probabilistic programming languages (PPLs), PPL support for variational inference (VI) is less developed: users are typically limited to a predefined selection of variational objectives and gradient estimators, which are implemented monolithically (and without formal correctness arguments) in PPL backends. In this paper, we propose a more modular approach to supporting variational inference in PPLs, based on compositional program transformation. In our approach, variational objectives are expressed as programs that may employ first-class constructs for computing densities of, and expected values under, user-defined models and variational families. We then systematically transform these programs into unbiased gradient estimators for optimizing the objectives they define. Our design enables modular reasoning about many interacting concerns, including automatic differentiation, density accumulation, tracing, and the application of unbiased gradient estimation strategies. Additionally, relative to existing support for VI in PPLs, our design increases expressiveness along three axes: (1) it supports an open-ended set of user-defined variational objectives, rather than a fixed menu of options; (2) it supports a combinatorial space of gradient estimation strategies, many of which are not automated by today's PPLs; and (3) it supports a broader class of models and variational families, because it supports constructs for approximate marginalization and normalization (previously introduced only for Monte Carlo inference). We implement our approach in an extension to the Gen probabilistic programming system (genjax.vi, implemented in JAX), and evaluate it on several deep generative modeling tasks, showing minimal performance overhead relative to hand-coded implementations and performance competitive with well-established open-source PPLs.
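To make the idea of "an objective expressed as a program, paired with a choice of gradient estimation strategy" concrete, the following is a minimal hand-written sketch in plain JAX. It does not use the genjax.vi API and is not the paper's program transformation. The toy model (z ~ N(0, 1), x | z ~ N(z, 1)), the observation `X_OBS`, and all function names are assumptions chosen for illustration. It writes one ELBO objective and estimates its gradient in two ways: by reparameterization and by a score-function (REINFORCE) surrogate.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

X_OBS = 2.0  # hypothetical observed value of x


def log_joint(z):
    # log p(z) + log p(x | z) for the toy model z ~ N(0, 1), x | z ~ N(z, 1)
    return norm.logpdf(z, 0.0, 1.0) + norm.logpdf(X_OBS, z, 1.0)


def log_q(z, params):
    # Log density of the variational family q(z) = N(mu, exp(log_sigma))
    mu, log_sigma = params
    return norm.logpdf(z, mu, jnp.exp(log_sigma))


def elbo_reparam(params, key, n):
    # Reparameterization: z is a differentiable function of the parameters,
    # so differentiating the sample average gives an unbiased gradient estimate.
    mu, log_sigma = params
    eps = jax.random.normal(key, (n,))
    z = mu + jnp.exp(log_sigma) * eps
    return jnp.mean(log_joint(z) - log_q(z, params))


def elbo_score_surrogate(params, key, n):
    # Score-function estimator: samples are treated as constants, and the
    # surrogate's gradient is f * grad(log q) + grad(f), which is unbiased
    # for the ELBO gradient but typically has higher variance.
    mu, log_sigma = params
    eps = jax.random.normal(key, (n,))
    z = jax.lax.stop_gradient(mu + jnp.exp(log_sigma) * eps)
    f = log_joint(z) - log_q(z, params)
    return jnp.mean(jax.lax.stop_gradient(f) * log_q(z, params) + f)


params = (jnp.array(0.0), jnp.array(0.0))  # (mu, log_sigma)
key = jax.random.PRNGKey(0)
g_reparam = jax.grad(elbo_reparam)(params, key, 20000)
g_score = jax.grad(elbo_score_surrogate)(params, key, 20000)
```

For this conjugate toy model the exact ELBO gradient is available in closed form, (x - 2μ, 1 - 2σ²) with respect to (μ, log σ), so both estimates can be checked against it. In a hand-coded setting like this one, switching estimators means rewriting the objective function; the approach described above instead aims to derive such estimators from a single objective program.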