Many areas of machine learning and science involve large linear algebra problems, such as eigendecompositions, solving linear systems, computing matrix exponentials, and trace estimation. The matrices involved often have Kronecker, convolutional, block diagonal, sum, or product structure. In this paper, we propose a simple but general framework for large-scale linear algebra problems in machine learning, named CoLA (Compositional Linear Algebra). By combining a linear operator abstraction with compositional dispatch rules, CoLA automatically constructs memory- and runtime-efficient numerical algorithms. Moreover, CoLA provides memory-efficient automatic differentiation, low-precision computation, and GPU acceleration in both JAX and PyTorch, while also accommodating new objects, operations, and rules in downstream packages via multiple dispatch. CoLA can accelerate many algebraic operations while making it easy to prototype matrix structures and algorithms, providing an appealing drop-in tool for virtually any computational effort that requires linear algebra. We showcase its efficacy across a broad range of applications, including partial differential equations, Gaussian processes, equivariant model construction, and unsupervised learning.
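The central idea above, a linear operator abstraction combined with dispatch rules that exploit structure, can be illustrated with a toy sketch. This is a hypothetical NumPy illustration, not CoLA's API: the names `LinearOperator`, `Dense`, `Diagonal`, `Kronecker`, `matmat`, `solve_cols`, and `solve` are invented for this example. It uses Python's `functools.singledispatch`, which dispatches only on the first argument, as a simplified stand-in for multiple dispatch. The single structural rule shown, $(A \otimes B)^{-1} = A^{-1} \otimes B^{-1}$ for invertible square factors, lets the Kronecker solve avoid forming the full matrix and recurse into each factor's own rule.

```python
from functools import singledispatch

import numpy as np


class LinearOperator:
    """Hypothetical minimal operator: a shape plus a matrix-vector product."""

    def __init__(self, shape):
        self.shape = shape

    def matvec(self, v):
        raise NotImplementedError


class Dense(LinearOperator):
    def __init__(self, A):
        A = np.asarray(A, dtype=float)
        super().__init__(A.shape)
        self.A = A

    def matvec(self, v):
        return self.A @ v


class Diagonal(LinearOperator):
    def __init__(self, d):
        d = np.asarray(d, dtype=float)
        super().__init__((d.size, d.size))
        self.d = d

    def matvec(self, v):
        return self.d * v


class Kronecker(LinearOperator):
    """A ⊗ B, stored only through its two factor operators."""

    def __init__(self, A, B):
        super().__init__((A.shape[0] * B.shape[0], A.shape[1] * B.shape[1]))
        self.A, self.B = A, B

    def matvec(self, v):
        # With row-major flattening, (A ⊗ B) vec(X) = vec(A X B^T).
        X = v.reshape(self.A.shape[1], self.B.shape[1])
        Y = matmat(self.A, X)                      # A X
        return matmat(self.B, Y.T).T.reshape(-1)   # (B (A X)^T)^T = A X B^T


def matmat(op, M):
    # Apply an operator to every column of M using only matvecs.
    return np.stack([op.matvec(M[:, j]) for j in range(M.shape[1])], axis=1)


def solve_cols(op, M):
    # Solve op x = m for every column m of M, dispatching on op's type.
    return np.stack([solve(op, M[:, j]) for j in range(M.shape[1])], axis=1)


@singledispatch
def solve(op, b):
    # Generic fallback: materialize the operator through matvecs against
    # the identity, then call a dense solver.
    return np.linalg.solve(matmat(op, np.eye(op.shape[1])), b)


@solve.register
def _(op: Dense, b):
    return np.linalg.solve(op.A, b)


@solve.register
def _(op: Diagonal, b):
    return b / op.d


@solve.register
def _(op: Kronecker, b):
    # Rule: (A ⊗ B)^{-1} = A^{-1} ⊗ B^{-1}, so x = vec(A^{-1} X B^{-T}).
    X = b.reshape(op.A.shape[0], op.B.shape[0])
    Y = solve_cols(op.A, X)                        # A^{-1} X
    return solve_cols(op.B, Y.T).T.reshape(-1)     # (B^{-1} Y^T)^T


# Usage: Kronecker product of a diagonal factor and a dense factor.
B = np.array([[4.0, 1.0], [1.0, 3.0]])
K = Kronecker(Diagonal([1.0, 2.0, 3.0]), Dense(B))
b = np.arange(1.0, 7.0)
x = solve(K, b)
print(np.allclose(K.matvec(x), b))  # True
```

Operator types without a registered rule fall through to the generic dense path, so in this sketch new structure is supported by registering one more `solve` rule, and nested compositions (for example, a Kronecker product whose factor is itself a Kronecker product) are handled by the same recursion.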