In decision-making problems with limited training data, policy functions approximated using deep neural networks often exhibit suboptimal performance. An alternative approach involves learning a world model from the limited data and determining actions through online search. However, the performance of this approach suffers from compounding errors that arise from inaccuracies in the learnt world model. While methods like TreeQN have attempted to address these inaccuracies by incorporating algorithmic structural biases into their architectures, the biases they introduce are often too weak for complex decision-making tasks. In this work, we introduce Differentiable Tree Search (DTS), a novel neural network architecture that significantly strengthens the inductive bias by embedding the algorithmic structure of a best-first online search algorithm. DTS employs a learnt world model to conduct a fully differentiable online search in latent state space. The world model is jointly optimised with the search algorithm, which enables the learning of a robust world model and mitigates the effect of model inaccuracies. We address potential Q-function discontinuities arising from a naive incorporation of best-first search by adopting a stochastic tree expansion policy, formulating search tree expansion as a decision-making task, and introducing an effective variance reduction technique for the gradient computation. We evaluate DTS in an offline-RL setting with limited training data on Procgen games and a grid navigation task, and demonstrate that DTS outperforms popular model-free and model-based baselines.
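The control flow of a best-first search with a stochastic expansion policy can be illustrated with a minimal sketch. It is not the DTS implementation: the learnt neural modules are replaced here by hand-written stand-ins (`transition_fn`, `reward_fn`, `value_fn`) on a toy one-dimensional chain, nothing is differentiable, and the softmax expansion policy over path values, the `Node` class and all function names are assumptions made purely for illustration.

```python
import math
import random

GAMMA = 0.9  # discount factor (illustrative choice)
GOAL = 3     # absorbing goal state of the toy chain (illustrative choice)

def transition_fn(state, action):
    # Stand-in for a learnt transition model: move along the chain.
    return GOAL if state == GOAL else state + action

def reward_fn(state, action):
    # Stand-in for a learnt reward model: 1 on first entering the goal.
    return 1.0 if state != GOAL and transition_fn(state, action) == GOAL else 0.0

def value_fn(state):
    # Stand-in for a learnt value model; here it is the exact optimal value.
    return 0.0 if state == GOAL else GAMMA ** (abs(GOAL - state) - 1)

def softmax(scores, temperature=1.0):
    top = max(scores)
    exps = [math.exp((s - top) / temperature) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

class Node:
    def __init__(self, state, reward=0.0, path_reward=0.0, depth=0):
        self.state = state
        self.reward = reward            # reward on the edge into this node
        self.path_reward = path_reward  # discounted reward sum from the root
        self.depth = depth
        self.children = {}              # action -> Node

def backup(node):
    # Bellman backup: leaf nodes use value_fn, inner nodes take the best child.
    if not node.children:
        return value_fn(node.state)
    return max(c.reward + GAMMA * backup(c) for c in node.children.values())

def best_first_q_values(root_state, actions, n_expansions=4, temperature=1.0, seed=0):
    rng = random.Random(seed)
    root = Node(root_state)
    open_nodes = [root]
    for _ in range(n_expansions):
        # Path value: rewards collected so far plus the discounted leaf value.
        scores = [n.path_reward + GAMMA ** n.depth * value_fn(n.state)
                  for n in open_nodes]
        # Stochastic expansion: sample a leaf instead of taking the argmax.
        probs = softmax(scores, temperature)
        idx = rng.choices(range(len(open_nodes)), weights=probs, k=1)[0]
        node = open_nodes.pop(idx)
        for a in actions:
            r = reward_fn(node.state, a)
            child = Node(transition_fn(node.state, a), r,
                         node.path_reward + GAMMA ** node.depth * r,
                         node.depth + 1)
            node.children[a] = child
            open_nodes.append(child)
    # Q-value of each root action, backed up through the expanded tree.
    return {a: c.reward + GAMMA * backup(c) for a, c in root.children.items()}

q_values = best_first_q_values(root_state=0, actions=(-1, 1))
print(q_values)
```

Because the stand-in `value_fn` is exact for this toy chain, the backed-up Q-values do not depend on which leaves happen to be sampled: moving towards the goal scores about 0.81 and moving away about 0.656. With an imperfect learnt model the expansion order would matter, which is the situation the stochastic expansion policy and joint optimisation described above are meant to handle.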