Modern deep learning models are over-parameterized, so different optima can differ widely in generalization performance. To account for this, Sharpness-Aware Minimization (SAM) modifies the underlying loss function to guide descent methods towards flatter minima, which arguably generalize better. In this paper, we focus on a variant of SAM known as micro-batch SAM (mSAM), which, during training, averages the updates generated by adversarial perturbations across several disjoint shards (micro-batches) of a mini-batch. We extend a recently developed and well-studied general framework for flatness analysis to show that distributed gradient computation for sharpness-aware minimization theoretically achieves even flatter minima. To support this theoretical result, we provide a thorough empirical evaluation on a variety of image classification and natural language processing tasks. We also show that, contrary to previous work, mSAM can be implemented in a flexible and parallelizable manner without significantly increasing computational costs. Our practical implementation of mSAM yields better generalization performance than SAM across a wide range of tasks, further supporting our theoretical framework.
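The averaging step that defines mSAM can be illustrated with a minimal sketch. It is not the paper's implementation: it assumes a toy linear least-squares model with an analytic gradient, and it assumes the usual first-order SAM perturbation (a step of length `rho` along the normalized gradient). The names `grad_mse`, `sam_gradient`, and `msam_step`, as well as the values of `rho`, `lr`, and the shard count, are hypothetical and chosen only for illustration.

```python
import math
import random


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def mse_loss(w, batch):
    """Mean of 0.5 * (w.x - y)^2 over a batch of (x, y) pairs."""
    return sum(0.5 * (dot(w, x) - y) ** 2 for x, y in batch) / len(batch)


def grad_mse(w, batch):
    """Analytic gradient of mse_loss with respect to w."""
    g = [0.0] * len(w)
    for x, y in batch:
        err = dot(w, x) - y
        for j, xj in enumerate(x):
            g[j] += err * xj / len(batch)
    return g


def sam_gradient(w, batch, rho):
    """Assumed SAM-style gradient: evaluate the gradient at the
    adversarially perturbed point w + rho * g / ||g||."""
    g = grad_mse(w, batch)
    norm = math.sqrt(dot(g, g)) + 1e-12  # guard against division by zero
    w_adv = [wi + rho * gi / norm for wi, gi in zip(w, g)]
    return grad_mse(w_adv, batch)


def msam_step(w, minibatch, num_micro, rho, lr):
    """One mSAM-style update: split the mini-batch into disjoint shards,
    compute a SAM-style gradient on each shard, and average them.
    Samples left over when the split is uneven are dropped."""
    size = len(minibatch) // num_micro
    shards = [minibatch[i * size:(i + 1) * size] for i in range(num_micro)]
    grads = [sam_gradient(w, shard, rho) for shard in shards]
    avg = [sum(col) / num_micro for col in zip(*grads)]
    return [wi - lr * gi for wi, gi in zip(w, avg)]


# Usage on synthetic data (illustrative values only).
rng = random.Random(0)
xs = [[rng.gauss(0, 1), rng.gauss(0, 1)] for _ in range(64)]
data = [(x, 2.0 * x[0] - 1.0 * x[1]) for x in xs]

w = [0.0, 0.0]
for _ in range(200):
    w = msam_step(w, rng.sample(data, 8), num_micro=4, rho=0.05, lr=0.1)
print("final loss:", mse_loss(w, data))
```

In this sketch, setting `num_micro=1` recovers a plain SAM-style step on the whole mini-batch. Because each shard's perturbation and gradient depend only on that shard, the per-shard computations are independent of one another, which is what makes a parallel implementation possible.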