Recent advances in reasoning domains with neural networks have primarily been enabled by a training recipe that optimizes Large Language Models, previously trained to predict the next-token in a sequence, with reinforcement learning algorithms. We introduce a framework to study the success of this paradigm, and we theoretically expose the optimization mechanisms by which reinforcement learning improves over next-token prediction in this setting. We study learning from mixture distributions of short and long ``chain-of-thought'' sequences encoding a single task. In particular, when the task consists of predicting the parity of $d$ bits and long sequences are rare, we show how reinforcement learning after next-token prediction enables autoregressive transformers to generalize, whereas mere next-token prediction requires extreme statistical or computational resources to do so. We further explain how reinforcement learning leverages increased test-time computation, manifested in longer responses, to facilitate this learning process. In a simplified setting, we theoretically prove that autoregressive linear models following this training recipe can efficiently learn to predict the parity of $d$ bits as long as the proportion of long demonstrations in the data mix is not exponentially small in the input dimension $d$. Finally, we demonstrate these same phenomena in other settings, including the post-training of Llama-series models on mixture variations of common mathematical reasoning benchmarks.


翻译:神经网络在推理领域的最新进展主要得益于一种训练方法,该方法通过强化学习算法优化先前已训练用于预测序列中下一词的大语言模型。我们引入了一个框架来研究这一范式的成功,并从理论上揭示了在此设置下强化学习如何通过优化机制改进下一词预测。我们研究了从编码单一任务的短、长“思维链”序列混合分布中进行学习。具体而言,当任务涉及预测 $d$ 比特的奇偶性且长序列较为罕见时,我们展示了继下一词预测之后使用强化学习如何使自回归 Transformer 模型实现泛化,而仅使用下一词预测则需要极端的统计或计算资源才能达到相同效果。我们进一步解释了强化学习如何利用测试时增加的计算量(表现为更长的响应)来促进这一学习过程。在一个简化设定中,我们从理论上证明,遵循此训练方法的自回归线性模型能够高效地学习预测 $d$ 比特的奇偶性,只要数据混合中长演示序列的比例相对于输入维度 $d$ 不是指数级小。最后,我们在其他场景中也验证了这些现象,包括在常见数学推理基准的混合变体上对 Llama 系列模型进行后训练。

0
下载
关闭预览

相关内容

FlowQA: Grasping Flow in History for Conversational Machine Comprehension
专知会员服务
34+阅读 · 2019年10月18日
Keras François Chollet 《Deep Learning with Python 》, 386页pdf
专知会员服务
163+阅读 · 2019年10月12日
Unsupervised Learning via Meta-Learning
CreateAMind
44+阅读 · 2019年1月3日
STRCF for Visual Object Tracking
统计学习与视觉计算组
15+阅读 · 2018年5月29日
Hierarchical Imitation - Reinforcement Learning
CreateAMind
19+阅读 · 2018年5月25日
Focal Loss for Dense Object Detection
统计学习与视觉计算组
12+阅读 · 2018年3月15日
IJCAI | Cascade Dynamics Modeling with Attention-based RNN
KingsGarden
13+阅读 · 2017年7月16日
国家自然科学基金
13+阅读 · 2017年12月31日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
3+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
2+阅读 · 2014年12月31日
A Multi-Objective Deep Reinforcement Learning Framework
VIP会员
相关资讯
Unsupervised Learning via Meta-Learning
CreateAMind
44+阅读 · 2019年1月3日
STRCF for Visual Object Tracking
统计学习与视觉计算组
15+阅读 · 2018年5月29日
Hierarchical Imitation - Reinforcement Learning
CreateAMind
19+阅读 · 2018年5月25日
Focal Loss for Dense Object Detection
统计学习与视觉计算组
12+阅读 · 2018年3月15日
IJCAI | Cascade Dynamics Modeling with Attention-based RNN
KingsGarden
13+阅读 · 2017年7月16日
相关基金
国家自然科学基金
13+阅读 · 2017年12月31日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
3+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
2+阅读 · 2014年12月31日
Top
微信扫码咨询专知VIP会员