The rapid rise of scientific machine learning (SciML) has expanded the role of differentiable modeling, surrogate modeling, and data-driven constitutive laws in large-scale simulation. The JAX framework provides an attractive environment for these workflows through automatically differentiable programs, vectorization, GPU acceleration, and while enabling seamless learning of surrogate models. However, large-scale simulation still relies on mature HPC infrastructure. Libraries, such as PETSc, provide scalable MPI-based parallelism, robust linear and nonlinear solvers, and advanced preconditioning capabilities that remain difficult to reproduce in JAX-only workflows. We present JetSCI, a hybrid JAX-PETSc framework that unifies these complementary strengths. JetSCI uses JAX for GPU-parallel differentiable discretizations and PETSc for robust, scalable solution of the resulting systems on distributed-memory architectures, exposing multilevel parallelism through GPU acceleration within nodes and MPI parallelism across nodes. For finite element discretizations of heterogeneous micromechanics problems, JetSCI outperforms JAX-only implementations in efficiency and accuracy.
翻译:科学机器学习的迅猛发展拓展了可微建模、代理建模和数据驱动本构关系在大规模仿真中的作用。JAX框架通过自动可微编程、向量化、GPU加速及代理模型的无缝学习能力,为这些工作流提供了极具吸引力的环境。然而,大规模仿真仍依赖于成熟的HPC基础设施。诸如PETSc等库提供了基于MPI的可扩展并行处理、稳健的线性和非线性求解器以及先进的预处理能力,这些在纯JAX工作流中难以复现。我们提出的JetSCI是一种融合两者互补优势的混合JAX-PETSc框架。JetSCI利用JAX实现GPU并行的可微离散化,同时借助PETSc在分布式内存架构上对所得系统进行稳健且可扩展的求解,通过节点内GPU加速与节点间MPI并行实现多层次并行。针对异质微观力学问题的有限元离散化,JetSCI在效率和精度上均优于纯JAX实现。