We present CosmoPower-JAX, a JAX-based implementation of the CosmoPower framework, which accelerates cosmological inference by building neural emulators of cosmological power spectra. We show how, using the automatic differentiation, batch evaluation and just-in-time compilation features of JAX, and running the inference pipeline on graphics processing units (GPUs), parameter estimation can be accelerated by orders of magnitude with advanced gradient-based sampling techniques. These can be used to efficiently explore high-dimensional parameter spaces, such as those needed for the analysis of next-generation cosmological surveys. We showcase the accuracy and computational efficiency of CosmoPower-JAX on two simulated Stage IV configurations. We first consider a single survey performing a cosmic shear analysis totalling 37 model parameters. We validate the contours derived with CosmoPower-JAX and a Hamiltonian Monte Carlo sampler against those derived with a nested sampler and without emulators, obtaining a speed-up factor of $\mathcal{O}(10^3)$. We then consider a combination of three Stage IV surveys, each performing a joint cosmic shear and galaxy clustering (3x2pt) analysis, for a total of 157 model parameters. Even with such a high-dimensional parameter space, CosmoPower-JAX provides converged posterior contours in 3 days, as opposed to the estimated 6 years required by standard methods. CosmoPower-JAX is fully written in Python, and we make it publicly available to help the cosmological community meet the accuracy requirements set by next-generation surveys.
翻译:我们提出CosmoPower-JAX,一种基于JAX的CosmoPower框架实现,通过构建宇宙学功率谱的神经仿真器来加速宇宙学推断。我们展示了如何利用JAX的自动微分、批量评估和即时编译特性,并结合在图形处理单元(GPU)上运行推断流程,借助先进的基于梯度的采样技术将参数估计加速数个数量级。这些技术可用于高效探索高维参数空间,例如分析下一代宇宙学巡天所需的空间。我们通过两个模拟的第四代(Stage IV)配置验证了CosmoPower-JAX的精度与计算效率。首先考虑单一巡天进行弱引力剪切分析,共包含37个模型参数。我们验证了使用CosmoPower-JAX与Hamiltonian Monte Carlo采样器导出的等高线,与使用嵌套采样器且无仿真器时导出的结果一致,加速因子达到$\mathcal{O}(10^3)$。随后考虑三个Stage IV巡天的组合,每个巡天进行联合弱引力剪切与星系成团性(3x2pt)分析,共涉及157个模型参数。即便在此高维参数空间中,CosmoPower-JAX在3天内即可提供收敛的后验等高线,而标准方法预计需要6年。CosmoPower-JAX完全用Python编写,我们将其公开提供,以帮助宇宙学界满足下一代巡天设置的精度要求。