Optimal transport (OT) has profoundly impacted machine learning by providing theoretical and computational tools to realign datasets. In this context, given two large point clouds of sizes $n$ and $m$ in $\mathbb{R}^d$, entropic OT (EOT) solvers have emerged as the most reliable tool either to solve the Kantorovich problem and output an $n\times m$ coupling matrix, or to solve the Monge problem and learn a vector-valued push-forward map. While the robustness of EOT couplings/maps makes them a go-to choice in practical applications, EOT solvers remain difficult to tune because of a small but influential set of hyperparameters, notably the omnipresent entropic regularization strength $\varepsilon$. Setting $\varepsilon$ can be difficult, as it simultaneously impacts several performance metrics, such as compute speed, statistical performance, generalization, and bias. In this work, we propose a new class of EOT solvers (ProgOT) that can estimate both plans and transport maps. We optimize the computation of EOT solutions by dividing mass displacement using a time discretization, borrowing inspiration from dynamic OT formulations, and conquering each of these steps using EOT with properly scheduled parameters. We provide experimental evidence demonstrating that ProgOT is a faster and more robust alternative to standard solvers when computing couplings at large scales, even outperforming neural network-based approaches. We also prove statistical consistency of our approach for estimating optimal transport maps.
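The divide-and-conquer idea can be illustrated with a minimal NumPy sketch. It is not the ProgOT implementation: the abstract does not specify how points are moved or how parameters are scheduled, so the barycentric-projection step, the step sizes `alphas`, and the relative regularization `eps_scale` are all assumptions made for illustration, and the function names `sinkhorn` and `progressive_eot` are hypothetical.

```python
import numpy as np


def _logsumexp(z, axis):
    """Numerically stable log(sum(exp(z))) along one axis."""
    zmax = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(zmax, axis=axis) + np.log(np.sum(np.exp(z - zmax), axis=axis))


def sinkhorn(x, y, eps_scale=0.05, n_iters=200):
    """Entropic OT coupling between uniformly weighted clouds x (n, d) and y (m, d).

    Assumption: epsilon is set to eps_scale times the mean squared-Euclidean cost.
    """
    n, m = len(x), len(y)
    cost = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    eps = eps_scale * cost.mean()
    log_a = np.full(n, -np.log(n))
    log_b = np.full(m, -np.log(m))
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iters):
        # Log-domain Sinkhorn updates of the two dual potentials.
        f = -eps * _logsumexp((g[None, :] - cost) / eps + log_b[None, :], axis=1)
        g = -eps * _logsumexp((f[:, None] - cost) / eps + log_a[:, None], axis=0)
    log_p = (f[:, None] + g[None, :] - cost) / eps + log_a[:, None] + log_b[None, :]
    return np.exp(log_p)


def progressive_eot(x, y, alphas, eps_scale=0.05, n_iters=200):
    """Move x toward y in several partial steps, then couple the moved points with y.

    Assumption: each step moves a point a fraction alpha of the way toward its
    barycentric projection under the current entropic coupling.
    """
    x = np.array(x, dtype=float)
    for alpha in alphas:
        p = sinkhorn(x, y, eps_scale, n_iters)
        # Barycentric projection: coupling-weighted average of the targets.
        target = (p @ y) / p.sum(axis=1, keepdims=True)
        x = (1.0 - alpha) * x + alpha * target
    return x, sinkhorn(x, y, eps_scale, n_iters)
```

Calling `progressive_eot(x, y, alphas=[1/4, 1/3, 1/2])` returns the displaced source points and an $n\times m$ coupling between them and the targets. With this particular choice of `alphas`, each step shrinks the gap between the two cloud means by the same absolute amount, which is one simple way to discretize the displacement in time.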