Wasserstein Gradient Flow (WGF) describes the gradient dynamics of probability density within the Wasserstein space. WGF provides a promising approach for conducting optimization over the probability distributions. Numerically approximating the continuous WGF requires the time discretization method. The most well-known method for this is the JKO scheme. In this regard, previous WGF models employ the JKO scheme and parametrize transport map for each JKO step. However, this approach results in quadratic training complexity $O(K^2)$ with the number of JKO step $K$. This severely limits the scalability of WGF models. In this paper, we introduce a scalable WGF-based generative model, called Semi-dual JKO (S-JKO). Our model is based on the semi-dual form of the JKO step, derived from the equivalence between the JKO step and the Unbalanced Optimal Transport. Our approach reduces the training complexity to $O(K)$. We demonstrate that our model significantly outperforms existing WGF-based generative models, achieving FID scores of 2.62 on CIFAR-10 and 6.19 on CelebA-HQ-256, which are comparable to state-of-the-art image generative models.
翻译:Wasserstein梯度流描述了Wasserstein空间中概率密度的梯度动力学,为概率分布优化提供了有效途径。对连续WGF进行数值逼近需要时间离散化方法,其中最著名的为JKO格式。先前WGF模型采用JKO格式并为每个JKO步骤参数化传输映射,但该方法导致训练复杂度与JKO步数K呈二次关系O(K²),严重限制了模型的可扩展性。本文提出可扩展的WGF生成模型——半对偶JKO。该模型基于JKO步与非平衡最优输运的等价性推导出的半对偶形式,将训练复杂度降低至O(K)。实验表明,本模型显著优于现有WGF生成模型,在CIFAR-10和CelebA-HQ-256上分别达到2.62和6.19的FID分数,可媲美最先进的图像生成模型。