In decision-making problems with limited training data, policy functions approximated using deep neural networks often exhibit suboptimal performance. An alternative approach is to learn a world model from the limited data and determine actions through online search. However, performance is adversely affected by compounding errors arising from inaccuracies in the learned world model. While methods like TreeQN have attempted to address these inaccuracies by incorporating algorithmic inductive biases into their neural network architectures, the biases they introduce are often weak and insufficient for complex decision-making tasks. In this work, we introduce Differentiable Tree Search Network (D-TSN), a novel neural network architecture that significantly strengthens the inductive bias by embedding the algorithmic structure of a best-first online search algorithm. D-TSN employs a learned world model to conduct a fully differentiable online search. The world model is jointly optimized with the search algorithm, enabling the learning of a robust world model and mitigating the effect of prediction inaccuracies. Further, we note that a naive incorporation of best-first search could lead to a loss function that is discontinuous in the parameter space. We address this issue by adopting a stochastic tree expansion policy, formulating search tree expansion as another decision-making task, and introducing an effective variance reduction technique for the gradient computation. We evaluate D-TSN in an offline RL setting with limited training data, on Procgen games and a grid navigation task, and demonstrate that D-TSN outperforms popular model-free and model-based baselines.
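To make the last idea concrete, here is a minimal, self-contained sketch of the gradient-estimation principle the abstract alludes to: a stochastic expansion policy is sampled from a softmax over candidate nodes, and a REINFORCE-style (score-function) gradient is computed with a baseline subtracted from the return to reduce variance. All names, the toy returns, and the constant baseline are illustrative assumptions for exposition, not details from the paper.

```python
import numpy as np

# Toy setup: a stochastic expansion policy over K candidate nodes,
# parameterized by softmax logits theta. We estimate the gradient of
# the expected return with the score-function (REINFORCE) estimator,
# with and without a baseline, and compare estimator variance.
# Values below (returns, noise scale, baseline) are illustrative only.

rng = np.random.default_rng(0)
K = 3
theta = np.zeros(K)                        # expansion-policy logits

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

true_returns = np.array([1.0, 0.5, 0.0])   # hypothetical return per expansion choice

def reinforce_grads(n_samples, baseline):
    """Per-sample score-function gradients (R - b) * grad log pi(a)."""
    pi = softmax(theta)
    grads = []
    for _ in range(n_samples):
        a = rng.choice(K, p=pi)
        R = true_returns[a] + rng.normal(0, 0.1)   # noisy observed return
        grad_logp = -pi.copy()
        grad_logp[a] += 1.0                        # d log pi(a) / d theta
        grads.append((R - baseline) * grad_logp)
    return np.array(grads)

g_no_base = reinforce_grads(5000, baseline=0.0)
g_base = reinforce_grads(5000, baseline=true_returns.mean())

# Both estimators are unbiased for the same gradient, but subtracting
# a baseline shrinks the per-coordinate variance substantially.
print("variance without baseline:", g_no_base.var(axis=0).sum())
print("variance with baseline:   ", g_base.var(axis=0).sum())
```

This is the generic mechanism, not D-TSN's specific estimator: the paper applies the idea to the tree-expansion decisions inside the differentiable search, so that the expected search outcome remains a smooth function of the network parameters despite the discrete expansion choices.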