We propose that the grokking phenomenon, where the train loss of a neural network decreases much earlier than its test loss, can arise when a neural network transitions from lazy training dynamics to a rich, feature-learning regime. To illustrate this mechanism, we study the simple setting of vanilla gradient descent on a polynomial regression problem with a two-layer neural network that exhibits grokking without regularization, in a way that cannot be explained by existing theories. We identify sufficient statistics for the test loss of such a network, and tracking them over training reveals that grokking arises in this setting when the network first attempts to fit a kernel regression solution with its initial features, followed by late-time feature learning in which a generalizing solution is identified after the train loss is already low. We provide an asymptotic theoretical description of the grokking dynamics in this model using dynamical mean field theory (DMFT) for high-dimensional data. We find that the key determinants of grokking are the rate of feature learning, which can be controlled precisely by parameters that scale the network output, and the alignment of the initial features with the target function $y(x)$. We argue that this delayed generalization arises when (1) the top eigenvectors of the initial neural tangent kernel are misaligned with the task labels $y(x)$, (2) the dataset is large enough for the network to eventually generalize, yet not so large that the train loss perfectly tracks the test loss at every epoch, and (3) the network begins training in the lazy regime and therefore does not learn features immediately. We conclude with evidence that this transition from lazy (linear model) to rich (feature learning) training can control grokking in more general settings, such as MNIST, one-layer Transformers, and student-teacher networks.
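The output-scaling knob described above can be illustrated with a minimal NumPy sketch. Every specific choice here is our own assumption, not the paper's exact setup: the centered output $f(x) = (h(x) - h_0(x))/(\gamma\sqrt{N})$ for a tanh network of width $N$, the parameter learning rate $\eta_0\gamma^2$ (which keeps the function-space step size roughly fixed while per-step feature movement grows with $\gamma$), the hypothetical single-index target $y(x) = \mathrm{He}_3(\beta\cdot x)$, and the hypothetical `overlap` diagnostic (mean squared cosine between neurons and $\beta$), which is not the paper's set of sufficient statistics. Small $\gamma$ keeps the network close to its linearized (kernel) model; larger $\gamma$ lets the first-layer features move.

```python
import numpy as np


def he3(z):
    # Probabilists' Hermite polynomial He_3(z) = z^3 - 3z (hypothetical target choice)
    return z**3 - 3.0 * z


def train_two_layer(gamma, n_train=200, n_test=1000, d=20, width=256,
                    eta0=0.1, steps=2000, log_every=100, seed=0):
    """Full-batch GD on a centered two-layer tanh network (illustrative sketch).

    Output: f = (net - net_at_init) / (gamma * sqrt(width)).
    Parameter learning rate: eta0 * gamma**2, so the function-space step is
    O(eta0) for any gamma, while per-step feature movement scales with gamma.
    """
    rng = np.random.default_rng(seed)
    beta = rng.standard_normal(d)
    beta /= np.linalg.norm(beta)                       # unit target direction
    X_tr = rng.standard_normal((n_train, d))
    X_te = rng.standard_normal((n_test, d))
    y_tr, y_te = he3(X_tr @ beta), he3(X_te @ beta)

    W = rng.standard_normal((width, d))                # first-layer weights
    a = rng.standard_normal(width)                     # readout weights
    W0, a0 = W.copy(), a.copy()
    scale = 1.0 / (gamma * np.sqrt(width))
    lr = eta0 * gamma**2

    def raw(X, W, a):
        # Uncentered network output h(x) / (gamma * sqrt(width))
        return np.tanh(X @ W.T / np.sqrt(d)) @ a * scale

    f0_tr, f0_te = raw(X_tr, W0, a0), raw(X_te, W0, a0)  # centering offsets

    hist = {"step": [], "train": [], "test": [], "overlap": []}
    for t in range(steps + 1):
        H = np.tanh(X_tr @ W.T / np.sqrt(d))           # (n_train, width) hidden features
        r = H @ a * scale - f0_tr - y_tr               # residual f(x) - y(x)
        if t % log_every == 0:
            f_te = raw(X_te, W, a) - f0_te
            hist["step"].append(t)
            hist["train"].append(0.5 * np.mean(r**2))
            hist["test"].append(0.5 * np.mean((f_te - y_te) ** 2))
            # Hypothetical diagnostic: mean squared cosine between each neuron and beta
            cos = W @ beta / np.linalg.norm(W, axis=1)
            hist["overlap"].append(np.mean(cos**2))
        if t == steps:
            break
        # Gradients of L = mean(r^2) / 2 with respect to a and W
        grad_a = H.T @ r * scale / n_train
        G = r[:, None] * (1.0 - H**2) * a[None, :] * scale / np.sqrt(d)
        grad_W = G.T @ X_tr / n_train
        a -= lr * grad_a
        W -= lr * grad_W

    # Relative movement of first-layer weights: small in the lazy regime
    hist["feature_change"] = np.linalg.norm(W - W0) / np.linalg.norm(W0)
    return hist


if __name__ == "__main__":
    for name, g in [("lazy", 0.05), ("rich", 2.0)]:
        h = train_two_layer(gamma=g)
        print(name, h["train"][-1], h["test"][-1], h["overlap"][-1], h["feature_change"])
```

Comparing the `train` and `test` histories across values of `gamma` shows how the output scale controls feature movement; whether a visible grokking gap appears depends on `n_train`, `d`, `width`, `gamma`, and `steps`, and these defaults are not tuned to reproduce any result reported in the paper.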