We present a new approach to the Neural Optimal Transport (NOT) training procedure, capable of accurately and efficiently estimating the optimal transportation plan via a specific regularization of the dual Kantorovich potentials. The main bottleneck of existing NOT solvers is the search for a near-exact approximation of the conjugate operator (i.e., the c-transform), which is done either by optimizing over non-convex max-min objectives or by computationally intensive fine-tuning of an initial approximate prediction. We resolve both issues by proposing a new, theoretically justified loss in the form of an expectile regularization, which enforces binding conditions on the learning of the dual potentials. Such a regularization provides an upper-bound estimate over the distribution of possible conjugate potentials and stabilizes the learning, completely eliminating the need for extensive additional fine-tuning. The proposed method, called Expectile-Regularised Neural Optimal Transport (ENOT), outperforms previous state-of-the-art approaches on the established Wasserstein-2 benchmark tasks by a large margin (up to a 3-fold improvement in quality and up to a 10-fold improvement in runtime). Moreover, we showcase the performance of ENOT with varying cost functions on tasks such as image generation, demonstrating the robustness of the proposed algorithm. Our implementation of the ENOT algorithm is available in the OTT-JAX library: https://ott-jax.readthedocs.io/en/latest/tutorials/ENOT.html
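To make the expectile regularization concrete, the sketch below shows the standard asymmetric-least-squares (expectile) loss and how it could act as a soft penalty on violations of the Kantorovich dual constraint f(x) + g(y) ≤ c(x, y). This is a minimal JAX illustration under assumed names (`expectile_loss`, `dual_constraint_penalty`, `f_x`, `g_y`, and `cost_xy` are all hypothetical), not the exact ENOT objective; the actual implementation is in the linked OTT-JAX tutorial.

```python
import jax.numpy as jnp


def expectile_loss(diff, tau=0.99):
    """Asymmetric least squares: |tau - 1{diff < 0}| * diff**2.

    For tau close to 1, positive residuals are penalized much more
    heavily than negative ones, so minimizing this loss pushes the
    fit towards an upper bound of the targets.
    """
    weight = jnp.where(diff > 0, tau, 1.0 - tau)
    return weight * diff ** 2


def dual_constraint_penalty(f_x, g_y, cost_xy, tau=0.99):
    """Hypothetical soft penalty on violations of f(x) + g(y) <= c(x, y).

    f_x:     (n,)   potential f evaluated at source samples.
    g_y:     (m,)   potential g evaluated at target samples.
    cost_xy: (n, m) pairwise costs c(x_i, y_j), e.g. squared Euclidean.
    """
    violation = f_x[:, None] + g_y[None, :] - cost_xy  # > 0 where violated
    return jnp.mean(expectile_loss(violation, tau))
```

With tau close to 1 the penalty is nearly one-sided: feasible pairs contribute almost nothing, while constraint violations are penalized quadratically. This is one way to read the abstract's "upper-bound estimate" claim, as it sidesteps an inner max-min search for the conjugate potential.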