The learning of Gaussian Mixture Models (also referred to simply as GMMs) plays an important role in machine learning. Known for their expressiveness and interpretability, Gaussian mixture models have a wide range of applications, from statistics, computer vision to distributional reinforcement learning. However, as of today, few known algorithms can fit or learn these models, some of which include Expectation-Maximization algorithms and Sliced Wasserstein Distance. Even fewer algorithms are compatible with gradient descent, the common learning process for neural networks. In this paper, we derive a closed formula of two GMMs in the univariate, one-dimensional case, then propose a distance function called Sliced Cram\'er 2-distance for learning general multivariate GMMs. Our approach has several advantages over many previous methods. First, it has a closed-form expression for the univariate case and is easy to compute and implement using common machine learning libraries (e.g., PyTorch and TensorFlow). Second, it is compatible with gradient descent, which enables us to integrate GMMs with neural networks seamlessly. Third, it can fit a GMM not only to a set of data points, but also to another GMM directly, without sampling from the target model. And fourth, it has some theoretical guarantees like global gradient boundedness and unbiased sampling gradient. These features are especially useful for distributional reinforcement learning and Deep Q Networks, where the goal is to learn a distribution over future rewards. We will also construct a Gaussian Mixture Distributional Deep Q Network as a toy example to demonstrate its effectiveness. Compared with previous models, this model is parameter efficient in terms of representing a distribution and possesses better interpretability.
翻译:高斯混合模型(简称GMMs)的学习在机器学习中扮演着重要角色。因其强大的表达能力和可解释性,高斯混合模型在统计学、计算机视觉到分布强化学习等领域具有广泛应用。然而,迄今为止,能够拟合或学习这些模型的已知算法寥寥无几,其中包括期望最大化算法和切片Wasserstein距离。更少的算法能够与神经网络中常用的梯度下降学习过程兼容。本文首先推导了一维单变量情形下两个GMM的闭式公式,随后提出了一种名为切片克拉默2-距离(Sliced Cramér 2-distance)的度量函数,用于学习一般多元GMMs。我们的方法相比以往多种方法具有若干优势。其一,它在单变量情形下具有闭式表达式,且易于使用常见机器学习库(如PyTorch和TensorFlow)进行计算和实现。其二,它与梯度下降兼容,使我们能够将GMMs与神经网络无缝集成。其三,它不仅能将GMM拟合到一组数据点,还可直接拟合到另一个GMM,无需从目标模型中采样。其四,它具备某些理论保证,如全局梯度有界性和无偏采样梯度。这些特性对于分布强化学习和深度Q网络尤为有用,其目标是学习未来奖励的分布。我们还将构建一个高斯混合分布深度Q网络作为简易示例,以证明其有效性。与先前模型相比,该模型在表示分布时参数效率更高,并具有更优的可解释性。