Benchmarks play an important role in the development of machine learning algorithms. For example, research in reinforcement learning (RL) has been heavily influenced by available environments and benchmarks. However, RL environments are traditionally run on the CPU, which limits their scalability with typical academic compute. Recent advancements in JAX have enabled the wider use of hardware acceleration to overcome these computational hurdles, making massively parallel RL training pipelines and environments possible. This is particularly useful for multi-agent reinforcement learning (MARL) research. First, multiple agents must be considered at each environment step, which adds computational burden; second, sample complexity is increased by non-stationarity, decentralised partial observability, and other MARL challenges. In this paper, we present JaxMARL, the first open-source code base that combines ease of use with GPU-enabled efficiency and supports a large number of commonly used MARL environments as well as popular baseline algorithms. In terms of wall-clock time, our experiments show that, per run, our JAX-based training pipeline is up to 12500x faster than existing approaches. This enables efficient and thorough evaluations, with the potential to alleviate the evaluation crisis of the field. We also introduce and benchmark SMAX, a vectorised, simplified version of the popular StarCraft Multi-Agent Challenge, which removes the need to run the StarCraft II game engine. This not only enables GPU acceleration, but also provides a more flexible MARL environment, unlocking the potential for self-play, meta-learning, and other future applications in MARL. We provide code at https://github.com/flairox/jaxmarl.