Solving large dense linear systems and eigenvalue problems is a core requirement in many areas of scientific computing, but scaling these operations beyond a single GPU remains challenging within modern programming frameworks. While highly optimized multi-GPU solver libraries exist, they are typically difficult to integrate into composable, just-in-time (JIT) compiled Python workflows. JAXMg provides multi-GPU dense linear algebra for JAX, enabling Cholesky-based linear solves and symmetric eigendecompositions for matrices that exceed single-GPU memory limits. By interfacing JAX with NVIDIA's cuSOLVERMg through an XLA Foreign Function Interface, JAXMg exposes distributed GPU solvers as JIT-compatible JAX primitives. This design allows scalable linear algebra to be embedded directly within JAX programs, preserving composability with JAX transformations and enabling multi-GPU execution in end-to-end scientific workflows.
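To make the kind of operation concrete: the sketch below shows a JIT-compiled Cholesky-based solve using standard single-device JAX primitives (`jax.scipy.linalg.cho_factor` / `cho_solve`). This is only an illustration of the workflow JAXMg distributes across GPUs; the actual JAXMg API is not shown here, since its call signatures are not given in this abstract.

```python
# Illustrative sketch (single-device JAX, NOT the JAXMg API): a JIT-compiled
# Cholesky-based solve of the form A x = b for a symmetric positive-definite A.
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

@jax.jit
def solve_spd(A, b):
    # Factor the SPD matrix once, then solve against the right-hand side.
    c, lower = cho_factor(A)
    return cho_solve((c, lower), b)

key = jax.random.PRNGKey(0)
M = jax.random.normal(key, (4, 4))
A = M @ M.T + 4.0 * jnp.eye(4)   # shift the Gram matrix to guarantee SPD
b = jnp.ones(4)

x = solve_spd(A, b)
residual = float(jnp.linalg.norm(A @ x - b))
```

Because `solve_spd` is an ordinary JIT-compatible function, it composes with other JAX transformations; the point of JAXMg is to offer the same composability for matrices too large for one GPU's memory, by routing the factorization through cuSOLVERMg.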