Motivated by the success of Transformers when applied to sequences of discrete symbols, token-based world models (TBWMs) were recently proposed as sample-efficient methods. In TBWMs, the world model consumes agent experience as a language-like sequence of tokens, where each observation constitutes a sub-sequence. However, during imagination, the sequential token-by-token generation of next observations results in a severe bottleneck, leading to long training times, poor GPU utilization, and limited representations. To resolve this bottleneck, we devise a novel Parallel Observation Prediction (POP) mechanism. POP augments a Retentive Network (RetNet) with a novel forward mode tailored to our reinforcement learning setting. We incorporate POP into a novel TBWM agent named REM (Retentive Environment Model), achieving imagination 15.4x faster than prior TBWMs. REM attains superhuman performance on 12 out of 26 games of the Atari 100K benchmark, while training in less than 12 hours. Our code is available at \url{https://github.com/leor-c/REM}.