Optimization algorithms designed to find flat minima have been remarkably successful in modern machine learning applications. Motivated by this design choice, we undertake a formal study that (i) formulates the notion of flat minima, and (ii) studies the complexity of finding them. Specifically, we adopt the trace of the Hessian of the cost function as a measure of flatness, and use it to formally define the notion of approximate flat minima. Under this notion, we then analyze algorithms that find approximate flat minima efficiently. For general cost functions, we discuss a gradient-based algorithm that finds an approximate flat local minimum efficiently. Its main component uses gradients computed at randomly perturbed iterates to estimate a direction that leads to flatter minima. For the setting where the cost function is an empirical risk over training data, we present a faster algorithm inspired by a recently proposed practical algorithm called sharpness-aware minimization, which lends theoretical support to that algorithm's success in practice.
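To illustrate how gradients at randomly perturbed iterates can reveal a direction toward flatter minima, the following is a minimal sketch and not the algorithm analyzed in this work. It assumes Gaussian perturbations and relies on the Taylor-expansion fact that the average of ∇f(x + ρg) over g ~ N(0, I) equals ∇f(x) + (ρ²/2) ∇tr(∇²f(x)) up to higher-order terms in ρ. The toy cost f(x, y) = ½(xy − 1)², the helper names (`grad_f`, `trace_hessian`, `flatness_direction`), and the parameter values are all hypothetical choices made for illustration. For this particular quartic cost the identity holds exactly in expectation, so the Monte Carlo estimate can be compared against the hand-derived gradient of the Hessian trace.

```python
import numpy as np


def grad_f(p):
    """Gradient of the toy cost f(x, y) = 0.5 * (x*y - 1)**2; p has shape (..., 2)."""
    x, y = p[..., 0], p[..., 1]
    r = x * y - 1.0
    return np.stack([r * y, r * x], axis=-1)


def trace_hessian(p):
    """Trace of the Hessian of the toy cost, derived by hand: f_xx + f_yy = y**2 + x**2."""
    return p[..., 0] ** 2 + p[..., 1] ** 2


def flatness_direction(grad, p, rho, n_samples, rng):
    """Monte Carlo estimate of the gradient of tr(Hessian) at p.

    Averages gradients at Gaussian-perturbed copies of p, subtracts the
    unperturbed gradient, and rescales by 2 / rho**2.
    """
    noise = rng.standard_normal((n_samples, p.shape[0]))
    perturbed = grad(p + rho * noise)  # shape (n_samples, d)
    return (2.0 / rho**2) * (perturbed.mean(axis=0) - grad(p))


rng = np.random.default_rng(0)
p = np.array([2.0, 0.5])  # a minimizer (x*y = 1), but not the flattest one
estimate = flatness_direction(grad_f, p, rho=1.0, n_samples=200_000, rng=rng)
exact = 2.0 * p  # gradient of x**2 + y**2, i.e. of the Hessian trace

print("tr(Hessian) at p:          ", trace_hessian(p))
print("estimated grad tr(Hessian):", estimate)
print("exact grad tr(Hessian):    ", exact)
```

Moving against the estimated direction lowers the Hessian trace but also leaves the set of minimizers, so a practical scheme would have to interleave such moves with ordinary gradient steps on the cost itself. That interleaving is not shown here.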