Model-based reinforcement learning (RL) offers a solution to the data inefficiency that plagues most model-free RL algorithms. However, learning a robust world model often demands complex, deep architectures that are expensive to compute and train. Within the world model, the dynamics model is particularly crucial for accurate predictions, and various dynamics-model architectures have been explored, each with its own set of challenges. Currently, recurrent neural network (RNN) based world models face issues such as vanishing gradients and difficulty in capturing long-term dependencies effectively. Transformers, in contrast, suffer from the well-known limitations of the self-attention mechanism, whose memory and computational complexity both scale as $O(n^2)$, where $n$ is the sequence length. To address these challenges, we propose a state space model (SSM) based world model, built on Mamba, that achieves $O(n)$ memory and computational complexity while effectively capturing long-term dependencies and enabling efficient training on longer sequences. We also introduce a novel sampling method to mitigate the suboptimality caused by an inaccurate world model in the early stages of training; combined with the architecture above, it achieves a normalised score comparable to other state-of-the-art model-based RL algorithms using a world model with only 7 million trainable parameters. The model is accessible and can be trained on an off-the-shelf laptop. Our code is available at https://github.com/realwenlongwang/drama.git.