Robust generalization is a major challenge in deep learning, particularly when the number of trainable parameters is very large. In general, it is very difficult to know if the network has memorized a particular set of examples or understood the underlying rule (or both). Motivated by this challenge, we study an interpretable model where generalizing representations are understood analytically, and are easily distinguishable from the memorizing ones. Namely, we consider multi-layer perceptron (MLP) and Transformer architectures trained on modular arithmetic tasks, where ($\xi \cdot 100\%$) of labels are corrupted (\emph{i.e.} some results of the modular operations in the training set are incorrect). We show that (i) it is possible for the network to memorize the corrupted labels \emph{and} achieve $100\%$ generalization at the same time; (ii) the memorizing neurons can be identified and pruned, lowering the accuracy on corrupted data and improving the accuracy on uncorrupted data; (iii) regularization methods such as weight decay, dropout and BatchNorm force the network to ignore the corrupted data during optimization, and achieve $100\%$ accuracy on the uncorrupted dataset; and (iv) the effect of these regularization methods is (``mechanistically'') interpretable: weight decay and dropout force all the neurons to learn generalizing representations, while BatchNorm de-amplifies the output of memorizing neurons and amplifies the output of the generalizing ones. Finally, we show that in the presence of regularization, the training dynamics involves two consecutive stages: first, the network undergoes \emph{grokking} dynamics reaching high train \emph{and} test accuracy; second, it unlearns the memorizing representations, where the train accuracy suddenly jumps from $100\%$ to $100 (1-\xi)\%$.
翻译:稳健泛化是深度学习中的重大挑战,尤其在可训练参数数量极大时。通常很难判断网络是记忆了特定示例集合,还是理解了底层规则(或两者兼具)。受此挑战启发,我们研究了一个可解释模型:在该模型中,泛化表示可通过分析理解,并容易与记忆性表示区分。具体而言,我们考虑在模算术任务上训练的多层感知机(MLP)和Transformer架构,其中($\xi \cdot 100\%$)的标签被损坏(即训练集中部分模运算结果不正确)。我们证明:(i) 网络可能同时记忆损坏标签并实现100%泛化;(ii) 可识别并剪除记忆性神经元,从而降低损坏数据上的准确率并提升未损坏数据上的准确率;(iii) 权重衰减、dropout和批量归一化等正则化方法会迫使网络在优化过程中忽略损坏数据,并在未损坏数据集上达到100%准确率;(iv) 这些正则化方法的效果具有(“机制性”)可解释性:权重衰减和dropout迫使所有神经元学习泛化表示,而批量归一化则减弱记忆性神经元的输出并增强泛化性神经元的输出。最后我们证明,在存在正则化的情况下,训练动态包含两个连续阶段:第一阶段网络经历“顿悟”(grokking)动态,达到高训练准确率与测试准确率;第二阶段它遗忘记忆性表示,此时训练准确率突然从100%跳变至$100 (1-\xi)\%$。