Solving large dense linear systems and eigenvalue problems is a core requirement in many areas of scientific computing, but scaling these operations beyond a single GPU remains challenging within modern programming frameworks. While highly optimized multi-GPU solver libraries exist, they are typically difficult to integrate into composable, just-in-time (JIT) compiled Python workflows. JAXMg provides multi-GPU dense linear algebra for JAX, enabling Cholesky-based linear solves and symmetric eigendecompositions for matrices that exceed single-GPU memory limits. By interfacing JAX with NVIDIA's cuSOLVERMg through an XLA Foreign Function Interface, JAXMg exposes distributed GPU solvers as JIT-compatible JAX primitives. This design allows scalable linear algebra to be embedded directly within JAX programs, preserving composability with JAX transformations and enabling multi-GPU execution in end-to-end scientific workflows.


翻译:求解大规模稠密线性系统和特征值问题是科学计算众多领域的核心需求,但在现代编程框架中将这些运算扩展至单个GPU之外仍具挑战。尽管存在高度优化的多GPU求解器库,它们通常难以集成到可组合、即时编译的Python工作流中。JAXMg为JAX提供了多GPU稠密线性代数功能,支持对超出单GPU内存限制的矩阵进行基于Cholesky分解的线性求解和对称特征分解。通过XLA外部函数接口将JAX与NVIDIA的cuSOLVERMg连接,JAXMg将分布式GPU求解器暴露为JIT兼容的JAX原语。该设计使得可扩展线性代数能够直接嵌入JAX程序中,保持与JAX变换的可组合性,并在端到端科学工作流中实现多GPU执行。

0
下载
关闭预览

相关内容

【ICML2024】MVMoE:具有专家混合的多任务车辆路径求解器
PyTorch实现多种深度强化学习算法
专知
36+阅读 · 2019年1月15日
推荐|TensorFlow/PyTorch/Sklearn实现的五十种机器学习模型
全球人工智能
24+阅读 · 2017年7月14日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
1+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
1+阅读 · 2015年12月31日
国家自然科学基金
3+阅读 · 2015年12月31日
国家自然科学基金
5+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
VIP会员
相关VIP内容
【ICML2024】MVMoE:具有专家混合的多任务车辆路径求解器
相关基金
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
1+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
1+阅读 · 2015年12月31日
国家自然科学基金
3+阅读 · 2015年12月31日
国家自然科学基金
5+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
Top
微信扫码咨询专知VIP会员