Particle-based variational inference methods (ParVIs) such as Stein variational gradient descent (SVGD) update the particles based on the kernelized Wasserstein gradient flow of the Kullback-Leibler (KL) divergence. However, designing a good kernel is often non-trivial, and the choice of kernel can limit the flexibility of the method. Recent work has shown that approximating the functional gradient flow with quadratic-form regularization terms can improve performance. In this paper, we propose a ParVI framework, called generalized Wasserstein gradient descent (GWG), based on a generalized Wasserstein gradient flow of the KL divergence, which can be viewed as a functional gradient method with a broader class of regularizers induced by convex functions. We show that GWG enjoys strong convergence guarantees. We also provide an adaptive version that automatically chooses the Wasserstein metric to accelerate convergence. In experiments, we demonstrate the effectiveness and efficiency of the proposed framework on both simulated and real-data problems.
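To make the "regularizer induced by a convex function" idea concrete, here is a minimal sketch under stated assumptions, not the authors' implementation. Suppose the functional gradient maximizes E_q[∇log p · f + ∇·f] − E_q[h(f)] over all vector fields f. By Stein's identity the first term equals E_q[(∇log p − ∇log q) · f], so the pointwise maximizer is f* = ∇h*(∇log p − ∇log q), where h* is the convex conjugate of h. For h(v) = (1/p) Σ_i |v_i|^p this gives ∇h*(y) = sign(y) |y|^(1/(p−1)), and p = 2 recovers the ordinary Wasserstein gradient flow. The standard Gaussian target, the function names, and especially the Gaussian fit used as a stand-in for ∇log q are all hypothetical choices made for illustration; the abstract does not say how GWG estimates this quantity.

```python
import numpy as np

def grad_h_star(y, p):
    """Gradient of the convex conjugate h* for h(v) = (1/p) * sum_i |v_i|^p, p > 1.
    Since grad h(v) = sign(v) |v|^(p-1), its inverse is sign(y) |y|^(1/(p-1))."""
    return np.sign(y) * np.abs(y) ** (1.0 / (p - 1.0))

def score_target(x):
    # Hypothetical target: standard Gaussian N(0, I), so grad log p(x) = -x.
    return -x

def score_particles_gaussian(x):
    # Crude stand-in for grad log q (an assumption for this sketch only):
    # fit a Gaussian to the current particles and use its score.
    mu = x.mean(axis=0)
    cov = np.cov(x, rowvar=False) + 1e-6 * np.eye(x.shape[1])
    # cov is symmetric, so each row becomes -(cov^{-1} (x_i - mu)).
    return -(x - mu) @ np.linalg.inv(cov)

def generalized_step(x, p=2.0, step=0.05):
    # Move every particle along grad h*(grad log p - grad log q).
    v = grad_h_star(score_target(x) - score_particles_gaussian(x), p)
    return x + step * v

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=0.5, size=(500, 2))  # particles start far from the target
for _ in range(500):
    x = generalized_step(x, p=2.0, step=0.05)
print("mean:", x.mean(axis=0), "std:", x.std(axis=0))
```

With p = 2 the particles drift to mean ≈ 0 and standard deviation ≈ 1, matching the target. Other values of p change the geometry of the update: each component of the score difference is rescaled nonlinearly while its sign is kept. This is the kind of flexibility that choosing the convex function h provides.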