Offline reinforcement learning (RL) leverages pre-collected datasets to train optimal policies. Diffusion Q-Learning (DQL), which introduced diffusion models as a powerful and expressive policy class, significantly boosts the performance of offline RL. However, its reliance on iterative denoising sampling to generate actions slows down both training and inference. While several recent attempts have been made to accelerate DQL, gains in training and/or inference speed often come at the cost of degraded performance. In this paper, we introduce a dual-policy approach, Diffusion Trusted Q-Learning (DTQL), which comprises a diffusion policy for pure behavior cloning and a practical one-step policy. We bridge the two policies with a newly introduced diffusion trust region loss. The diffusion policy maintains expressiveness, while the trust region loss directs the one-step policy to explore freely and seek modes within the region defined by the diffusion policy. DTQL eliminates the need for iterative denoising sampling during both training and inference, making it remarkably computationally efficient. We evaluate its effectiveness and algorithmic characteristics against popular Kullback--Leibler divergence-based distillation methods in 2D bandit scenarios and gym tasks. We then show that DTQL not only outperforms other methods on the majority of D4RL benchmark tasks but is also efficient in both training and inference speed. The PyTorch implementation is available at https://github.com/TianyuCodings/Diffusion_Trusted_Q_Learning.
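To make the dual-policy idea concrete, the following is a rough PyTorch sketch of one update step, assuming the mechanism described above: a frozen diffusion policy trained by behavior cloning supplies a denoising loss that is low for in-distribution actions, and a one-step policy is trained to maximize the critic's Q-value while that denoising loss keeps its actions inside the behavior distribution's modes. All network shapes, the linear noising schedule, and the weight `alpha` are illustrative assumptions, not the paper's exact architecture or hyperparameters.

```python
import torch
import torch.nn as nn

state_dim, action_dim = 4, 2  # toy dimensions (assumption)

# Frozen diffusion policy's noise-prediction network, assumed already
# trained by pure behavior cloning on the offline dataset.
eps_net = nn.Sequential(nn.Linear(state_dim + action_dim + 1, 64), nn.ReLU(),
                        nn.Linear(64, action_dim))
for p in eps_net.parameters():
    p.requires_grad_(False)

# One-step policy: maps a state directly to an action, so no iterative
# denoising is needed at training or inference time.
one_step = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(),
                         nn.Linear(64, action_dim), nn.Tanh())

# Critic Q(s, a), assumed trained separately with a standard TD loss.
q_net = nn.Sequential(nn.Linear(state_dim + action_dim, 64), nn.ReLU(),
                      nn.Linear(64, 1))

def trust_region_loss(s, a):
    """Denoising loss of the frozen diffusion model evaluated at the
    one-step policy's actions: it is small when `a` lies in a mode of
    the behavior distribution, so minimizing it keeps the one-step
    policy inside the region defined by the diffusion policy."""
    t = torch.rand(s.shape[0], 1)          # random diffusion time in [0, 1]
    noise = torch.randn_like(a)
    a_t = (1 - t) * a + t * noise          # simple linear noising (assumption)
    pred = eps_net(torch.cat([s, a_t, t], dim=-1))
    return ((pred - noise) ** 2).mean()

opt = torch.optim.Adam(one_step.parameters(), lr=3e-4)
s = torch.randn(32, state_dim)             # toy batch of states
a = one_step(s)
alpha = 1.0                                # Q-maximization weight (assumption)
loss = trust_region_loss(s, a) - alpha * q_net(torch.cat([s, a], dim=-1)).mean()
opt.zero_grad()
loss.backward()
opt.step()
```

Note how neither term requires sampling from the diffusion model: the frozen network is only evaluated on noised versions of the one-step policy's own actions, which is what removes iterative denoising from both training and inference.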