Wasserstein Gradient Flow (WGF) describes the gradient dynamics of probability density within the Wasserstein space. WGF provides a promising approach for conducting optimization over the probability distributions. Numerically approximating the continuous WGF requires the time discretization method. The most well-known method for this is the JKO scheme. In this regard, previous WGF models employ the JKO scheme and parametrize transport map for each JKO step. However, this approach results in quadratic training complexity $O(K^2)$ with the number of JKO step $K$. This severely limits the scalability of WGF models. In this paper, we introduce a scalable WGF-based generative model, called Semi-dual JKO (S-JKO). Our model is based on the semi-dual form of the JKO step, derived from the equivalence between the JKO step and the Unbalanced Optimal Transport. Our approach reduces the training complexity to $O(K)$. We demonstrate that our model significantly outperforms existing WGF-based generative models, achieving FID scores of 2.62 on CIFAR-10 and 5.46 on CelebA-HQ-256, which are comparable to state-of-the-art image generative models.
翻译:Wasserstein梯度流描述了Wasserstein空间中概率密度的梯度动力学,为概率分布优化提供了有效方法。连续Wasserstein梯度流的数值逼近需要时间离散化方法,其中最著名的是JKO格式。现有WGF模型均采用JKO格式,并对每个JKO步骤的参数化传输映射进行建模。然而,该方法导致训练复杂度随JKO步数K呈二次增长O(K²),严重限制了WGF模型的可扩展性。本文提出一种可扩展的基于WGF的生成模型——半对偶JKO(S-JKO)。该模型基于JKO步骤的半对偶形式,该形式源自JKO步骤与非平衡最优传输的等价性。我们的方法将训练复杂度降低至O(K)。实验表明,该模型显著优于现有基于WGF的生成模型,在CIFAR-10数据集上FID得分为2.62,在CelebA-HQ-256数据集上为5.46,达到了与最先进图像生成模型相当的性能。