Despite the remarkable capabilities of modern large language models (LLMs), the mechanisms behind their problem-solving abilities remain elusive. In this work, we aim to better understand how the learning dynamics of LLM finetuning shape downstream generalization. Our analysis focuses on reasoning tasks, whose problem structure allows us to distinguish between memorization (the exact replication of reasoning steps from the training data) and performance (the correctness of the final solution). We find that a model's generalization behavior can be effectively characterized by a training metric we call pre-memorization train accuracy: the accuracy of model samples on training queries before the model begins to copy the exact reasoning steps from the training set. On the dataset level, this metric reliably predicts test accuracy, achieving $R^2$ values at or above 0.9 across various models (Llama3 8B, Gemma2 9B), datasets (GSM8k, MATH), and training configurations. On a per-example level, this metric is also indicative of whether individual model predictions are robust to perturbations in the training query. By connecting a model's learning behavior to its generalization, pre-memorization train accuracy can guide targeted improvements to training strategies. We focus on data curation as an example, and show that prioritizing examples with low pre-memorization accuracy yields 1.5-2x improvements in data efficiency over i.i.d. data scaling, and outperforms other standard data curation techniques.
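The metric described above can be operationalized in several ways; the sketch below is one plausible reading of the abstract's definition, not the paper's exact procedure. It assumes that, for each training example, we have per-checkpoint records of (a) whether a sampled solution reached the correct final answer and (b) whether the sample reproduced the training rationale verbatim (memorization). The helper names (`pre_memorization_accuracy`, `prioritize_for_curation`) are hypothetical.

```python
from typing import List


def pre_memorization_accuracy(
    per_epoch_correct: List[bool],
    per_epoch_memorized: List[bool],
) -> float:
    """One plausible per-example estimate: the fraction of checkpoints,
    before the first checkpoint at which the model copies the training
    reasoning steps verbatim, where a sampled solution was correct."""
    # Locate the first checkpoint where memorization begins.
    try:
        first_mem = per_epoch_memorized.index(True)
    except ValueError:
        first_mem = len(per_epoch_memorized)  # never memorized
    if first_mem == 0:
        return 0.0  # memorized immediately; no pre-memorization window
    window = per_epoch_correct[:first_mem]
    return sum(window) / len(window)


def prioritize_for_curation(scores: List[float], k: int) -> List[int]:
    """Curation sketch: pick the k examples with the lowest
    pre-memorization accuracy, as the abstract suggests prioritizing."""
    return sorted(range(len(scores)), key=lambda i: scores[i])[:k]
```

For instance, an example that is correct at checkpoints 0 and 2 but starts reproducing the training rationale at checkpoint 2 would score 0.5 under this estimate; curation would then prefer such low-scoring examples when selecting additional training data.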