Large Language Models (LLMs) have exhibited remarkable performance on reasoning tasks. They use autoregressive token generation to construct reasoning trajectories, enabling the development of a coherent chain of thought. In this work, we explore the impact of individual tokens on the final outcomes of reasoning tasks. We identify the existence of ``critical tokens'' that lead to incorrect reasoning trajectories in LLMs. Specifically, we find that LLMs tend to produce positive outcomes when forced to decode tokens other than the critical tokens. Motivated by this observation, we propose a novel approach, cDPO, designed to automatically recognize and assign token-level rewards to critical tokens during the alignment process. Specifically, we develop a contrastive estimation approach that automatically identifies critical tokens by comparing the generation likelihoods of a positive and a negative model. To this end, we separately fine-tune the positive and negative models on different reasoning trajectories, so that they can identify critical tokens within incorrect trajectories that contribute to erroneous outcomes. Moreover, to further align the model with the critical-token information during the alignment process, we extend the conventional DPO algorithm to token-level DPO and use the differential likelihood between the aforementioned positive and negative models as the importance weight for token-level DPO learning. Experimental results on the GSM8K and MATH500 benchmarks with two widely used models, Llama-3 (8B and 70B) and deepseek-math (7B), demonstrate the effectiveness of the proposed cDPO approach.
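The contrastive estimation step described above can be sketched as follows. This is a minimal illustrative sketch, not the paper's released implementation: it assumes we already have per-token log-likelihoods of an incorrect trajectory under the positive and negative fine-tuned models, and the function and variable names are hypothetical.

```python
# Illustrative sketch of contrastive critical-token estimation: tokens that
# the negative model (fine-tuned toward incorrect trajectories) assigns much
# higher likelihood than the positive model are flagged as candidate critical
# tokens, and the differential likelihoods are normalized into importance
# weights for token-level DPO learning. All names here are hypothetical.
import math

def critical_token_weights(logp_pos, logp_neg):
    """Given per-token log-likelihoods of one incorrect trajectory under the
    positive and negative models, return softmax-normalized weights where a
    large weight marks a candidate critical token."""
    # Differential likelihood: how much more the negative model prefers
    # each token than the positive model does.
    scores = [lp_n - lp_p for lp_p, lp_n in zip(logp_pos, logp_neg)]
    # Softmax normalization (with max-subtraction for numerical stability)
    # turns the scores into importance weights summing to 1.
    max_s = max(scores)
    exps = [math.exp(s - max_s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Toy example: token index 2 is far more likely under the negative model
# than under the positive one, so it receives the largest weight.
logp_pos = [-1.0, -1.2, -6.0, -0.8]
logp_neg = [-1.1, -1.3, -1.0, -0.9]
weights = critical_token_weights(logp_pos, logp_neg)
```

In a full pipeline, these per-token weights would then scale the token-level DPO loss terms so that the update concentrates on the tokens most responsible for the erroneous outcome.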