We propose conditional stochastic interpolation (CSI), an approach to learning conditional distributions. CSI learns probability flow equations or stochastic differential equations that transport a reference distribution to the target conditional distribution. It first learns the drift function and the conditional score function based on conditional stochastic interpolation, then uses them to construct either a deterministic process governed by an ordinary differential equation or a diffusion process for conditional sampling. The proposed CSI model incorporates an adaptive diffusion term to address instability that arises during training. Under mild conditions, we provide explicit forms of the conditional score function and the drift function in terms of conditional expectations, which naturally lead to a nonparametric regression approach to estimating these functions. Furthermore, we establish non-asymptotic error bounds, in terms of KL divergence, for learning the target conditional distribution via conditional stochastic interpolation, taking into account the neural network approximation error. We illustrate the application of CSI to image generation using a benchmark image dataset.
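To make the roles of the drift and the conditional score concrete, the following is a sketch in one commonly used interpolation setting. The notation is an assumption made for illustration and is not taken from the abstract: a Gaussian reference $Z$ independent of the data pair $(X, Y_1)$, scalar schedules $a(t)$ and $b(t)$, and a generic diffusion coefficient $\gamma(t)$, which is not claimed to be the adaptive diffusion term mentioned above.

```latex
% Assumed setting (illustrative, not the paper's stated construction):
%   Z ~ N(0, I_d) independent of (X, Y_1),
%   a(0) = 1, b(0) = 0, a(1) = 0, b(1) = 1, with a, b differentiable.
\begin{align*}
  % Interpolant between the reference Z and the response Y_1
  Y_t &= a(t)\, Z + b(t)\, Y_1, \qquad t \in [0, 1], \\
  % Drift as a conditional expectation
  v(x, y, t) &= \mathbb{E}\!\left[\, a'(t)\, Z + b'(t)\, Y_1 \,\middle|\, X = x,\ Y_t = y \,\right], \\
  % Conditional score as a conditional expectation, valid where a(t) > 0
  s(x, y, t) &= \nabla_y \log p_t(y \mid x)
              = -\frac{1}{a(t)}\, \mathbb{E}\!\left[\, Z \,\middle|\, X = x,\ Y_t = y \,\right], \\
  % Deterministic sampler (probability flow ODE), started from Y_0 ~ N(0, I_d)
  \mathrm{d} Y_t &= v(x, Y_t, t)\, \mathrm{d}t, \\
  % Stochastic sampler with the same conditional marginals p_t(. | x)
  \mathrm{d} Y_t &= \Big[ v(x, Y_t, t) + \tfrac{1}{2}\gamma(t)^2\, s(x, Y_t, t) \Big]\, \mathrm{d}t
                   + \gamma(t)\, \mathrm{d}W_t .
\end{align*}
% Because a conditional expectation minimizes mean squared error, the drift
% can be estimated by least-squares regression over a function class F
% (for example, neural networks):
\begin{equation*}
  v \in \operatorname*{arg\,min}_{f \in \mathcal{F}}
  \int_0^1 \mathbb{E}\, \big\| f(X, Y_t, t) - \big( a'(t)\, Z + b'(t)\, Y_1 \big) \big\|^2 \, \mathrm{d}t .
\end{equation*}
```

In this setting the two samplers share the same conditional marginals because the score correction in the drift cancels the diffusion term in the Fokker-Planck equation. The score can be fitted by the same kind of regression, with $Z$ as the target and the result scaled by $-1/a(t)$.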