Benchmarks play an important role in the development of machine learning algorithms. For example, research in reinforcement learning (RL) has been heavily influenced by available environments and benchmarks. However, RL environments are traditionally run on the CPU, which limits their scalability with typical academic compute. Recent advancements in JAX have enabled the wider use of hardware acceleration to overcome these computational hurdles, enabling massively parallel RL training pipelines and environments. This is particularly useful for multi-agent reinforcement learning (MARL) research, for two reasons. First, multiple agents must be considered at each environment step, adding computational burden. Second, sample complexity is increased by non-stationarity, decentralised partial observability, and other MARL challenges. In this paper, we present JaxMARL, the first open-source code base that combines ease of use with GPU-enabled efficiency, and supports a large number of commonly used MARL environments as well as popular baseline algorithms. In terms of wall-clock time, our experiments show that, per run, our JAX-based training pipeline is up to 12500x faster than existing approaches. This enables efficient and thorough evaluations, with the potential to alleviate the evaluation crisis of the field. We also introduce and benchmark SMAX, a vectorised, simplified version of the popular StarCraft Multi-Agent Challenge, which removes the need to run the StarCraft II game engine. This not only enables GPU acceleration, but also provides a more flexible MARL environment, unlocking the potential for self-play, meta-learning, and other future applications in MARL. We provide code at https://github.com/flairox/jaxmarl.
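To illustrate what "vectorised" and "massively parallel" mean in this setting, here is a minimal JAX sketch of a toy multi-agent environment batched with `jax.vmap` and compiled with `jax.jit`. It is not JaxMARL's API: the `reset` and `step` functions, the environment dynamics, and the constants `NUM_AGENTS` and `NUM_ENVS` are all hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

NUM_AGENTS = 3    # hypothetical number of agents per environment
NUM_ENVS = 1024   # hypothetical number of environments stepped in parallel


def reset(key):
    """Toy reset: each agent starts at a random position in [-1, 1]."""
    return jax.random.uniform(key, (NUM_AGENTS,), minval=-1.0, maxval=1.0)


def step(state, actions):
    """Toy step for ONE environment.

    actions holds one integer in {0, 1, 2} per agent, meaning
    move left, stay, or move right by 0.1.
    """
    new_state = jnp.clip(state + 0.1 * (actions - 1), -1.0, 1.0)
    # Shared team reward: the agents are rewarded for gathering together.
    reward = -(new_state.max() - new_state.min())
    return new_state, reward


# vmap turns the single-environment functions into batched ones;
# jit compiles them so the whole batch runs as one accelerator program.
batched_reset = jax.jit(jax.vmap(reset))
batched_step = jax.jit(jax.vmap(step))

key = jax.random.PRNGKey(0)
key_reset, key_act = jax.random.split(key)

states = batched_reset(jax.random.split(key_reset, NUM_ENVS))
actions = jax.random.randint(key_act, (NUM_ENVS, NUM_AGENTS), 0, 3)
states, rewards = batched_step(states, actions)

print(states.shape, rewards.shape)
```

Because the environment logic is written as pure functions of arrays, the same code runs on CPU, GPU, or TPU, and the batch size can be scaled without changing `step` itself. This is the general pattern that the abstract's speed-up claims rest on, shown here under the assumptions stated above.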