Optimal transport theory has provided machine learning with several tools to infer a push-forward map between densities from samples. While this theory has seen tremendous methodological developments in machine learning in recent years, its practical implementation remains notoriously difficult, as it is hampered by both computational and statistical challenges. Because of these difficulties, existing approaches rarely depart from the default choice of estimating such maps with the squared-Euclidean distance as the ground cost, $c(x,y)=\|x-y\|^2_2$. We follow a different path in this work, motivated by the goal of \emph{learning} a suitable cost structure that encourages maps to transport points along engineered features. We extend the recently proposed Monge-Bregman-Occam pipeline~\citep{cuturi2023monge}, which rests on an alternative cost formulation that is also translation-invariant, $c(x,y)=h(x-y)$, but adopts the more general form $h=\tfrac12 \ell_2^2+\tau$, where $\tau$ is an appropriately chosen regularizer. We first propose a method that builds upon proximal gradient descent to generate ground-truth transports for such structured costs, using the notions of $h$-transforms and $h$-concave potentials. We show more generally that this method can be extended to compute $h$-transforms of entropic potentials. We study a regularizer that promotes transport displacements lying in low-dimensional subspaces, and propose to learn the corresponding basis change using Riemannian gradient descent on the Stiefel manifold. We show that these changes lead to estimators that are more robust and easier to interpret.
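To illustrate the proximal-gradient approach to $h$-transforms, here is a minimal sketch. It is not the authors' implementation. It assumes the convention $g^h(x)=\min_y h(x-y)-g(y)$, the regularizer $\tau=\gamma\|\cdot\|_1$ (an elastic-type choice made here for illustration), and a smooth, concave potential $g$ with Lipschitz gradient. The function names are hypothetical. The sketch runs ISTA on the displacement $z=x-y$, treating $\tfrac12\|z\|^2-g(x-z)$ as the smooth part and handling $\gamma\|z\|_1$ through its proximal operator (soft-thresholding).

```python
import numpy as np


def soft_threshold(v, thresh):
    """Proximal operator of thresh * ||.||_1 (elementwise soft-thresholding)."""
    return np.sign(v) * np.maximum(np.abs(v) - thresh, 0.0)


def h_transform_l1(x, g, grad_g, gamma, lipschitz, n_iter=500, tol=1e-10):
    """Approximate g^h(x) = min_y h(x - y) - g(y), h = 0.5*||.||^2 + gamma*||.||_1.

    Hypothetical helper. Optimizes over the displacement z = x - y with proximal
    gradient descent (ISTA): smooth part s(z) = 0.5*||z||^2 - g(x - z), nonsmooth
    part gamma*||z||_1. Assumes g is concave with L_g-Lipschitz gradient, so s is
    convex with (1 + L_g)-Lipschitz gradient; pass lipschitz = 1 + L_g.
    Returns the value g^h(x) and the minimizer y.
    """
    step = 1.0 / lipschitz
    z = np.zeros_like(x)
    for _ in range(n_iter):
        grad_s = z + grad_g(x - z)  # gradient of 0.5*||z||^2 - g(x - z)
        z_new = soft_threshold(z - step * grad_s, step * gamma)
        converged = np.linalg.norm(z_new - z) < tol
        z = z_new
        if converged:
            break
    y = x - z
    value = 0.5 * z @ z + gamma * np.abs(z).sum() - g(y)
    return value, y


# Usage with an illustrative concave quadratic potential g(y) = -a/2 * ||y - b||^2.
a, b = 1.0, np.zeros(3)
g = lambda y: -0.5 * a * np.sum((y - b) ** 2)
grad_g = lambda y: -a * (y - b)
x = np.array([2.0, 0.1, -1.5])
value, y_star = h_transform_l1(x, g, grad_g, gamma=0.5, lipschitz=1.0 + a)
# y_star ≈ [1.25, 0.1, -1.0]: the second coordinate is not displaced at all.
```

For this quadratic $g$, the smooth part has a scalar Hessian, so one proximal step of size $1/L$ already lands on the minimizer. A general $g$ needs more iterations. The zero in the recovered displacement $x-y^\star$ shows how $\tau$ shapes displacements, which the squared-Euclidean cost alone would not do.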
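The Stiefel-manifold step can be illustrated on a toy problem. This sketch makes two assumptions. First, it assumes the subspace regularizer penalizes the part of a displacement lying outside $\mathrm{span}(U)$, for example $\tau(z)=\tfrac{\gamma}{2}\|(I-UU^\top)z\|_2^2$ with $U\in\mathbb{R}^{d\times p}$ and $U^\top U=I_p$. Second, it replaces the actual training loss with a hypothetical surrogate: the average of $\|(I-UU^\top)z_i\|_2^2$ over a fixed set of displacements $z_i$. The helper names are hypothetical. The tangent-space projection (embedded metric) and QR retraction used here are standard choices and may differ from those in the paper.

```python
import numpy as np


def qr_retraction(m):
    """Map a d x p matrix back onto the Stiefel manifold {U : U^T U = I_p}."""
    q, r = np.linalg.qr(m)  # reduced QR: q is d x p with orthonormal columns
    return q * np.sign(np.diag(r))  # fix column signs so the retraction is unique


def riemannian_grad(u, euclid_grad):
    """Project a Euclidean gradient onto the tangent space of the Stiefel manifold at u."""
    sym = 0.5 * (u.T @ euclid_grad + euclid_grad.T @ u)
    return euclid_grad - u @ sym


def learn_subspace(z, p, n_iter=500, seed=0):
    """Hypothetical surrogate: min over U in St(d, p) of mean ||(I - U U^T) z_i||^2.

    Up to a constant this equals -tr(U^T C U) with C = Z^T Z / n, so the minimizer
    spans the p-dimensional subspace that best explains the displacements z_i.
    """
    n, d = z.shape
    c = z.T @ z / n
    step = 0.1 / np.linalg.norm(c, 2)  # conservative step from the largest eigenvalue
    rng = np.random.default_rng(seed)
    u = qr_retraction(rng.normal(size=(d, p)))  # random point on the manifold
    for _ in range(n_iter):
        euclid_grad = -2.0 * c @ u  # Euclidean gradient of -tr(U^T C U)
        u = qr_retraction(u - step * riemannian_grad(u, euclid_grad))
    return u


# Usage: displacements in R^5 that mostly live in span(e1, e2), plus small noise.
rng = np.random.default_rng(1)
n, d = 200, 5
z = np.zeros((n, d))
z[:, :2] = rng.normal(size=(n, 2)) * np.array([3.0, 2.0])
z += 0.05 * rng.normal(size=(n, d))
u = learn_subspace(z, p=2)
```

Because this surrogate is a Rayleigh-quotient objective, its minimizer is the top-$p$ principal subspace of the displacements. That makes the output easy to check against an eigendecomposition, even though the paper's actual loss for learning the basis is different.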