Modern deep learning models are over-parameterized, so different optima can yield widely varying generalization performance. Sharpness-Aware Minimization (SAM) modifies the underlying loss function to steer gradient descent methods toward flatter minima, which are believed to generalize better. Our study examines a specific variant of SAM known as micro-batch SAM (mSAM), which aggregates updates derived from adversarial perturbations across multiple shards (micro-batches) of a mini-batch during training. We extend a recently developed and well-studied general framework for flatness analysis to show theoretically that SAM achieves flatter minima than SGD, and that mSAM achieves even flatter minima than SAM. We substantiate this theoretical result with a thorough empirical evaluation on a range of image classification and natural language processing tasks. We also show that, contrary to previous work, mSAM can be implemented in a flexible and parallelizable manner without significantly increasing computational costs. Our implementation of mSAM yields better generalization performance than SAM across a wide range of tasks, further supporting our theoretical framework.
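To make the mSAM update described above concrete, the following is a minimal NumPy sketch and not the authors' implementation. It assumes a toy least-squares loss, the standard normalized-gradient SAM perturbation of radius `rho`, and plain averaging as the aggregation over micro-batches. The function names (`loss_and_grad`, `sam_gradient`, `msam_step`) and the values of `rho`, `lr`, and `num_micro_batches` are hypothetical choices made for illustration only.

```python
import numpy as np

def loss_and_grad(w, X, y):
    """Mean squared error of a linear model and its gradient w.r.t. w (toy loss, assumed)."""
    residual = X @ w - y
    loss = 0.5 * np.mean(residual ** 2)
    grad = X.T @ residual / len(y)
    return loss, grad

def sam_gradient(w, X, y, rho):
    """Gradient evaluated at the adversarially perturbed point w + eps (SAM-style)."""
    _, g = loss_and_grad(w, X, y)
    # Perturb along the normalized gradient; the small constant avoids division by zero.
    eps = rho * g / (np.linalg.norm(g) + 1e-12)
    _, g_adv = loss_and_grad(w + eps, X, y)
    return g_adv

def msam_step(w, X, y, num_micro_batches=4, rho=0.05, lr=0.1):
    """One mSAM-style step: a separate perturbation per micro-batch, then an average."""
    X_shards = np.array_split(X, num_micro_batches)
    y_shards = np.array_split(y, num_micro_batches)
    # Each shard computes its own perturbed gradient; these are independent,
    # so in principle they could be computed in parallel.
    grads = [sam_gradient(w, Xs, ys, rho) for Xs, ys in zip(X_shards, y_shards)]
    # Averaging is the aggregation assumed in this sketch.
    return w - lr * np.mean(grads, axis=0)

# Toy usage on synthetic, noiseless data (hypothetical setup).
rng = np.random.default_rng(0)
X = rng.normal(size=(64, 5))
w_true = rng.normal(size=5)
y = X @ w_true

w = np.zeros(5)
initial_loss, _ = loss_and_grad(w, X, y)
for _ in range(200):
    w = msam_step(w, X, y)
final_loss, _ = loss_and_grad(w, X, y)
print(f"loss before: {initial_loss:.4f}, after: {final_loss:.6f}")
```

With `num_micro_batches=1` the step reduces to ordinary SAM on the whole mini-batch, and with `rho=0` it reduces to a plain gradient step, which makes the relationship between SGD, SAM, and mSAM easy to see in this simplified setting.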