Many areas of machine learning and science involve large linear algebra problems, such as eigendecompositions, solving linear systems, computing matrix exponentials, and trace estimation. The matrices involved often have Kronecker, convolutional, block diagonal, sum, or product structure. In this paper, we propose a simple but general framework for large-scale linear algebra problems in machine learning, named CoLA (Compositional Linear Algebra). By combining a linear operator abstraction with compositional dispatch rules, CoLA automatically constructs memory- and runtime-efficient numerical algorithms. Moreover, CoLA provides memory-efficient automatic differentiation, low-precision computation, and GPU acceleration in both JAX and PyTorch, while also accommodating new objects, operations, and rules in downstream packages via multiple dispatch. CoLA can accelerate many algebraic operations, while making it easy to prototype matrix structures and algorithms, providing an appealing drop-in tool for virtually any computational effort that requires linear algebra. We showcase its efficacy across a broad range of applications, including partial differential equations, Gaussian processes, equivariant model construction, and unsupervised learning.
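To make the idea of "a linear operator abstraction with compositional dispatch rules" concrete, here is a minimal sketch in plain NumPy. It is not CoLA's API: the classes `LinearOperator`, `Dense`, `Diagonal`, `Kronecker` and the function `solve` are hypothetical names chosen for illustration. It also swaps in Python's `functools.singledispatch`, which dispatches only on the type of the first argument, as a simplified stand-in for the multiple dispatch described above. The point it shows is composition: a solve rule for `Kronecker` uses the identity $(A \otimes B)^{-1} = A^{-1} \otimes B^{-1}$ and recursively calls `solve` on each factor, so whatever rule applies to a factor (here, an elementwise division for `Diagonal`) is picked up automatically, and the full Kronecker matrix is never formed.

```python
import functools
import numpy as np

# Hypothetical minimal operator framework (illustration only, not CoLA's API).

def _matmat(op, M):
    # Apply op.matvec to every column of M.
    return np.stack([op.matvec(M[:, j]) for j in range(M.shape[1])], axis=1)

class LinearOperator:
    """An operator is defined by its shape and a matrix-vector product."""
    def __init__(self, shape):
        self.shape = shape
    def matvec(self, v):
        raise NotImplementedError
    def to_dense(self):
        # Materialize column by column: column j is op @ e_j.
        return np.stack([self.matvec(e) for e in np.eye(self.shape[1])], axis=1)

class Dense(LinearOperator):
    def __init__(self, A):
        super().__init__(A.shape)
        self.A = A
    def matvec(self, v):
        return self.A @ v

class Diagonal(LinearOperator):
    def __init__(self, d):
        super().__init__((d.size, d.size))
        self.d = d
    def matvec(self, v):
        return self.d * v

class Kronecker(LinearOperator):
    def __init__(self, A, B):
        super().__init__((A.shape[0] * B.shape[0], A.shape[1] * B.shape[1]))
        self.A, self.B = A, B
    def matvec(self, v):
        # With row-major reshaping, (A kron B) vec(X) = vec(A X B^T).
        X = v.reshape(self.A.shape[1], self.B.shape[1])
        Y = _matmat(self.A, X)       # A X
        Z = _matmat(self.B, Y.T).T   # (B (A X)^T)^T = A X B^T
        return Z.reshape(-1)

@functools.singledispatch
def solve(op, b):
    # Generic fallback: materialize the operator and use a dense O(n^3) solve.
    return np.linalg.solve(op.to_dense(), b)

@solve.register
def _(op: Dense, b):
    return np.linalg.solve(op.A, b)

@solve.register
def _(op: Diagonal, b):
    # O(n) rule for diagonal structure.
    return b / op.d

def _solve_columns(op, M):
    # Apply solve(op, .) to every column of M, dispatching on op's type.
    return np.stack([solve(op, M[:, j]) for j in range(M.shape[1])], axis=1)

@solve.register
def _(op: Kronecker, b):
    # (A kron B)^{-1} = A^{-1} kron B^{-1}, so x = vec(A^{-1} Bmat B^{-T}).
    Bmat = b.reshape(op.A.shape[0], op.B.shape[0])
    Y = _solve_columns(op.A, Bmat)       # A^{-1} Bmat
    Z = _solve_columns(op.B, Y.T).T      # Y B^{-T}
    return Z.reshape(-1)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    A = Dense(rng.standard_normal((3, 3)) + 3 * np.eye(3))
    B = Diagonal(np.array([1.0, 2.0, 4.0, 8.0]))
    K = Kronecker(A, B)
    b = rng.standard_normal(12)
    x = solve(K, b)  # dispatches to the Kronecker rule, then Dense and Diagonal
    print(np.allclose(np.kron(A.A, np.diag(B.d)) @ x, b))  # True
```

In this sketch the cost of solving with a Kronecker product of an $n \times n$ and a $p \times p$ factor is governed by solves with the two small factors rather than one dense solve with the $np \times np$ matrix, which is the kind of structure-aware saving the abstract attributes to compositional dispatch. New structures can be supported by adding a class and registering a rule, without touching existing code.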