Motivated by the success of Transformers when applied to sequences of discrete symbols, token-based world models (TBWMs) were recently proposed as sample-efficient methods. In TBWMs, the world model consumes agent experience as a language-like sequence of tokens, where each observation constitutes a sub-sequence. However, during imagination, the sequential token-by-token generation of next observations results in a severe bottleneck, leading to long training times, poor GPU utilization, and limited representations. To resolve this bottleneck, we devise a novel Parallel Observation Prediction (POP) mechanism. POP augments a Retentive Network (RetNet) with a novel forward mode tailored to our reinforcement learning setting. We incorporate POP in a novel TBWM agent named REM (Retentive Environment Model), demonstrating imagination 15.4x faster than that of prior TBWMs. REM attains superhuman performance on 12 out of 26 games of the Atari 100K benchmark, while training in less than 12 hours. Our code is available at \url{https://github.com/leor-c/REM}.