Learning distance functions between complex objects, such as the Wasserstein distance used to compare point sets, is a common goal in machine learning applications. However, functions on such complex objects (e.g., point sets and graphs) are often required to be invariant to a wide variety of group actions, e.g., permutations or rigid transformations. Therefore, continuous and symmetric functions defined on pairs of such objects (such as distance functions) must also be invariant to the product of these group actions. We call these functions symmetric and factor-wise group invariant (SFGI) functions. In this paper, we first present a general neural network architecture for approximating SFGI functions. Our main contribution combines this general neural network with a sketching idea to develop a specific and efficient neural network that can approximate the $p$-th Wasserstein distance between point sets. Importantly, the required model complexity is independent of the sizes of the input point sets. On the theoretical front, to the best of our knowledge, this is the first result showing that there exists a neural network with the capacity to approximate the Wasserstein distance with bounded model complexity. Our work provides an interesting integration of sketching ideas for geometric problems with universal approximation of symmetric functions. On the empirical front, we present a range of results showing that our newly proposed neural network architecture performs comparably to or better than other models (including a SOTA Siamese autoencoder-based approach). In particular, our neural network generalizes significantly better and trains much faster than the SOTA Siamese AE. Finally, this line of investigation could be useful in exploring effective neural network design for solving a broad range of geometric optimization problems (e.g., $k$-means in a metric space).
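To make the target function and the SFGI property concrete, the following is a minimal sketch, assuming NumPy and two point sets of equal size $n$ with uniform weights $1/n$. In that special case an optimal transport plan can be taken to be a permutation, so `wasserstein_p` computes the exact $p$-th Wasserstein distance by brute force over all $n!$ matchings (practical only for tiny $n$). The function `sfgi_toy_model` is a hypothetical, untrained network shaped to satisfy the same invariances (per-point encoder with sum pooling for each set, then a symmetric sum of the two set embeddings); it is an illustration only, not the architecture or sketching scheme proposed in the paper.

```python
import itertools
import numpy as np


def wasserstein_p(A, B, p=2):
    """Exact p-Wasserstein distance between two point sets of equal size n
    with uniform weights 1/n. Here an optimal plan is a permutation, so
    brute force over all n! matchings is exact (tiny n only)."""
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    if A.shape != B.shape:
        raise ValueError("this sketch assumes |A| == |B| and the same dimension")
    best = np.inf
    for perm in itertools.permutations(range(len(A))):
        # average p-th power cost of matching A[i] to B[perm[i]]
        cost = np.mean(np.linalg.norm(A - B[list(perm)], axis=1) ** p)
        best = min(best, cost)
    return best ** (1.0 / p)


def sfgi_toy_model(A, B, W_psi, W_rho, w_out):
    """Hypothetical SFGI-shaped network with random (untrained) weights.
    Sum pooling over a per-point map makes each set embedding invariant to
    permuting that set; adding the two embeddings makes the output
    invariant to swapping A and B."""
    emb_A = np.tanh(A @ W_psi).sum(axis=0)
    emb_B = np.tanh(B @ W_psi).sum(axis=0)
    hidden = np.tanh((emb_A + emb_B) @ W_rho)
    return float(hidden @ w_out)


rng = np.random.default_rng(0)
A = rng.normal(size=(5, 2))
B = rng.normal(size=(5, 2))
W_psi = rng.normal(size=(2, 16))
W_rho = rng.normal(size=(16, 16))
w_out = rng.normal(size=16)

# Reorder the points inside each set independently (a factor-wise action).
A_perm = A[rng.permutation(5)]
B_perm = B[rng.permutation(5)]

d = wasserstein_p(A, B, p=2)
print(np.isclose(d, wasserstein_p(B, A, p=2)))            # symmetric: True
print(np.isclose(d, wasserstein_p(A_perm, B_perm, p=2)))  # factor-wise invariant: True

m = sfgi_toy_model(A, B, W_psi, W_rho, w_out)
print(np.isclose(m, sfgi_toy_model(B_perm, A_perm, W_psi, W_rho, w_out)))  # True
```

The brute-force loop is there only because it is obviously correct; for realistic $n$ one would use an assignment or optimal-transport solver instead, and the cost of computing $W_p$ exactly at scale is part of what motivates learning an approximation.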