Gradient descent is one of the most widely used iterative algorithms in modern statistical learning. However, its precise algorithmic dynamics in high-dimensional settings remain only partially understood, limiting its broader use in statistical inference. This paper provides a precise, non-asymptotic distributional characterization of gradient descent iterates for a broad class of empirical risk minimization problems in the so-called mean-field regime, where the sample size is proportional to the signal dimension. Our non-asymptotic state evolution theory holds for general non-convex loss functions and non-Gaussian data, and reveals the central role of two Onsager correction matrices that precisely characterize the non-trivial dependence among all gradient descent iterates in the mean-field regime. Although the Onsager correction matrices are typically analytically intractable, our state evolution theory facilitates a generic gradient descent inference algorithm that consistently estimates these matrices across a broad class of models. Leveraging this algorithm, we show that the state evolution can be inverted to construct (i) data-driven estimators for the generalization error of gradient descent iterates and (ii) debiased gradient descent iterates for inference of the unknown signal. Detailed applications to two canonical models, linear regression and (generalized) logistic regression, are worked out to illustrate model-specific features of our general theory and inference methods.
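As a minimal illustration of the setting above, the sketch below runs plain gradient descent on a least-squares empirical risk in the proportional regime, where the sample size n is a constant multiple of the dimension p. The step size, iteration count, noise level, and scalings are arbitrary illustrative choices, not values prescribed by the theory; the paper's state evolution and Onsager corrections are not implemented here.

```python
import numpy as np

# Illustrative sketch only: gradient descent iterates mu^(t) for
# least-squares empirical risk minimization, with n proportional to p
# (the "mean-field" sample-size regime). All constants are arbitrary.
rng = np.random.default_rng(0)
n, p = 400, 200                              # n/p = 2, proportional regime
mu_star = rng.normal(size=p)                 # unknown signal
A = rng.normal(size=(n, p)) / np.sqrt(n)     # design matrix, 1/sqrt(n) scaling
y = A @ mu_star + 0.1 * rng.normal(size=n)   # linear regression observations

eta = 0.5                                    # step size (illustrative)
mu = np.zeros(p)                             # initialization mu^(0)
for t in range(100):
    grad = A.T @ (A @ mu - y)                # gradient of 0.5 * ||y - A mu||^2
    mu = mu - eta * grad                     # gradient descent update

# Squared estimation error per coordinate of the final iterate.
mse = np.mean((mu - mu_star) ** 2)
```

In this well-conditioned toy setup the iterates approach the least-squares solution, so `mse` ends up far below the per-coordinate signal variance; characterizing the iterate distribution at each finite t, rather than only at convergence, is what the state evolution theory addresses.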