Understanding the generalization of overparametrized neural networks remains a fundamental challenge in machine learning. Most of the literature studies generalization from an interpolation point of view, taking convergence of the parameters towards a global minimum of the training loss for granted. While overparametrized architectures do interpolate the data on typical classification tasks, this interpolation paradigm no longer seems valid for more complex tasks such as in-context learning or diffusion. For such tasks, it has instead been empirically observed that trained models go from global minima to spurious local minima of the training loss once the number of training samples exceeds a level we call the optimization threshold. While the former yield poor generalization to the true population loss, the latter were observed to actually correspond to minimizers of this true loss. This paper theoretically explores this phenomenon in the context of two-layer ReLU networks. We demonstrate that, despite overparametrization, such networks often converge toward simpler solutions rather than interpolating the training data, which can lead to a drastic improvement in the test loss with respect to interpolating solutions. Our analysis relies on the so-called early alignment phase, during which neurons align towards specific directions. This directional alignment, which occurs in the early stage of training, leads to a simplicity bias, wherein the network approximates the ground truth model without converging to the global minimum of the training loss. Our results suggest that this bias, which gives rise to an optimization threshold beyond which interpolation is no longer reached, is beneficial and enhances the generalization of trained models.
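The small-initialization regime in which the early alignment phase arises can be illustrated numerically. The following is a minimal, hypothetical sketch (not the paper's experimental setup): a two-layer ReLU network with small random initialization is trained by full-batch gradient descent on a toy regression task generated by a single teacher ReLU neuron, and the cosine similarity between each hidden neuron's direction and the teacher direction is inspected after training. All sizes, step counts, and the teacher direction `w_star` are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy regression task: n samples in d=2, labels from a single teacher ReLU neuron.
n, d, m = 8, 2, 20          # samples, input dim, hidden width (overparametrized)
X = rng.normal(size=(n, d))
w_star = np.array([1.0, 0.5]) / np.linalg.norm([1.0, 0.5])
y = np.maximum(X @ w_star, 0.0)

# Small initialization: the regime in which the early alignment phase occurs.
scale = 1e-3
W = scale * rng.normal(size=(m, d))   # input weights, one row per neuron
a = scale * rng.normal(size=m)        # output weights

def loss(W, a):
    pred = np.maximum(X @ W.T, 0.0) @ a
    return 0.5 * np.mean((pred - y) ** 2)

lr, steps = 0.1, 5000
init_loss = loss(W, a)
for _ in range(steps):
    H = np.maximum(X @ W.T, 0.0)            # (n, m) hidden activations
    r = H @ a - y                           # residuals
    mask = (X @ W.T > 0).astype(float)      # ReLU derivative
    grad_W = ((r[:, None] * mask) * a).T @ X / n
    grad_a = H.T @ r / n
    W -= lr * grad_W
    a -= lr * grad_a

# Inspect neuron directions: with small initialization, active neurons tend to
# cluster around a few directions rather than spreading to interpolate the data.
dirs = W / (np.linalg.norm(W, axis=1, keepdims=True) + 1e-12)
cos = dirs @ w_star
print(f"loss: {init_loss:.4e} -> {loss(W, a):.4e}")
print(f"max |cos(w_j, w*)| over neurons: {np.abs(cos).max():.3f}")
```

Varying `scale` is informative: with a large initialization the alignment phase disappears and the network tends to fit the samples more directly, while the small-`scale` run above exhibits the directional clustering associated with the simplicity bias discussed in the abstract.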