We study the problem of learning a single neuron under standard squared loss in the presence of arbitrary label noise and group-level distributional shifts, for a broad family of covariate distributions. Our goal is to identify a ''best-fit'' neuron parameterized by $\mathbf{w}_*$ that performs well under the most challenging reweighting of the groups. Specifically, we address a Group Distributionally Robust Optimization problem: given sample access to $K$ distinct distributions $\mathcal p_{[1]},\dots,\mathcal p_{[K]}$, we seek to approximate $\mathbf{w}_*$ that minimizes the worst-case objective over convex combinations of group distributions $\boldsymbolλ \in Δ_K$, where the objective is $\sum_{i \in [K]}λ_{[i]}\,\mathbb E_{(\mathbf x,y)\sim\mathcal p_{[i]}}(σ(\mathbf w\cdot\mathbf x)-y)^2 - νd_f(\boldsymbolλ,\frac{1}{K}\mathbf1)$ and $d_f$ is an $f$-divergence that imposes (optional) penalty on deviations from uniform group weights, scaled by a parameter $ν\geq 0$. We develop a computationally efficient primal-dual algorithm that outputs a vector $\widehat{\mathbf w}$ that is constant-factor competitive with $\mathbf{w}_*$ under the worst-case group weighting. Our analytical framework directly confronts the inherent nonconvexity of the loss function, providing robust learning guarantees in the face of arbitrary label corruptions and group-specific distributional shifts. The implementation of the dual extrapolation update motivated by our algorithmic framework shows promise on LLM pre-training benchmarks.
翻译:研究在存在任意标签噪声和群体级分布偏移的情况下,针对广泛协变量分布族,学习标准平方损失下单个神经元的问题。我们的目标是识别一个由$\mathbf{w}_*$参数化的“最佳适配”神经元,使其在最具挑战性的群体重加权下表现良好。具体而言,我们解决了一个群体分布鲁棒优化问题:给定对$K$个不同分布$\mathcal p_{[1]},\dots,\mathcal p_{[K]}$的样本访问权限,我们寻求近似$\mathbf{w}_*$,该向量最小化群体分布凸组合$\boldsymbolλ \in Δ_K$上的最坏情况目标,其中目标函数为$\sum_{i \in [K]}λ_{[i]}\,\mathbb E_{(\mathbf x,y)\sim\mathcal p_{[i]}}(σ(\mathbf w\cdot\mathbf x)-y)^2 - νd_f(\boldsymbolλ,\frac{1}{K}\mathbf1)$,而$d_f$是一个$f$-散度,用于对均匀群体权重的偏离施加(可选)惩罚,其惩罚程度由参数$ν\geq 0$缩放。我们开发了一种计算高效的原始-对偶算法,该算法输出向量$\widehat{\mathbf w}$,在最坏情况群体加权下与$\mathbf{w}_*$具有常数因子竞争力。我们的分析框架直接应对损失函数固有的非凸性,在存在任意标签破坏和群体特定分布偏移的情况下,提供了鲁棒学习保证。由我们的算法框架驱动的对偶外推更新实现,在LLM预训练基准测试中显示出前景。