We study the problem of learning a single neuron with respect to the $L_2^2$-loss in the presence of adversarial distribution shifts, where the labels can be arbitrary, and the goal is to find a ``best-fit'' function. More precisely, given training samples from a reference distribution $p_0$, the goal is to approximate the vector $\mathbf{w}^*$ that minimizes the squared loss with respect to the worst-case distribution that is close in $\chi^2$-divergence to $p_0$. We design a computationally efficient algorithm that recovers a vector $\hat{\mathbf{w}}$ satisfying $\mathbb{E}_{p^*} (\sigma(\hat{\mathbf{w}} \cdot \mathbf{x}) - y)^2 \leq C \, \mathbb{E}_{p^*} (\sigma(\mathbf{w}^* \cdot \mathbf{x}) - y)^2 + \epsilon$, where $C>1$ is a dimension-independent constant and $(\mathbf{w}^*, p^*)$ is the witness attaining the min-max risk $\min_{\mathbf{w}~:~\|\mathbf{w}\| \leq W} \max_{p} \left\{ \mathbb{E}_{(\mathbf{x}, y) \sim p} (\sigma(\mathbf{w} \cdot \mathbf{x}) - y)^2 - \nu \chi^2(p, p_0) \right\}$. Our algorithm follows a primal-dual framework and is designed by directly bounding the risk with respect to the original, nonconvex $L_2^2$ loss. From an optimization standpoint, our work opens new avenues for the design of primal-dual algorithms under structured nonconvexity.
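To make the min-max objective concrete, the following is a minimal, illustrative sketch of a generic primal-dual scheme for an empirical analogue of the problem; it is not the paper's actual algorithm, whose guarantees require a more careful design and analysis. The sketch assumes $\sigma$ is the ReLU and that the adversary reweights the $n$ training samples with weights $q$ in the simplex, so that $\chi^2(q, \hat{p}_0) = n\|q - \mathbf{1}/n\|^2$ against the uniform empirical distribution. All names (\texttt{primal\_dual\_fit}, \texttt{project\_simplex}) and the step-size/penalty parameters are hypothetical.

\begin{verbatim}
import numpy as np

def project_simplex(v):
    # Euclidean projection onto the probability simplex
    # (standard sort-based algorithm, cf. Duchi et al., 2008).
    u = np.sort(v)[::-1]
    css = np.cumsum(u)
    rho = np.nonzero(u + (1.0 - css) / (np.arange(len(v)) + 1) > 0)[0][-1]
    theta = (css[rho] - 1.0) / (rho + 1.0)
    return np.maximum(v - theta, 0.0)

def relu(z):
    return np.maximum(z, 0.0)

def relu_grad(z):
    return (z > 0).astype(float)

def primal_dual_fit(X, y, W=1.0, nu=1.0, eta=0.01, iters=500):
    # Illustrative primal-dual iteration for
    #   min_{||w|| <= W} max_{q in simplex} q . losses(w) - nu * chi^2(q, 1/n),
    # an empirical stand-in for the min-max risk in the abstract.
    n, d = X.shape
    w = np.zeros(d)
    u = np.full(n, 1.0 / n)  # uniform reference weights (empirical p_0)
    for _ in range(iters):
        z = X @ w
        losses = (relu(z) - y) ** 2
        # Dual step: exact best response. With the penalty
        # nu * chi^2(q, u) = nu * n * ||q - u||^2, the inner maximizer of
        # q . losses - nu * n * ||q - u||^2 over the simplex is the Euclidean
        # projection of u + losses / (2 * nu * n) onto the simplex
        # (complete the square in q).
        q = project_simplex(u + losses / (2.0 * nu * n))
        # Primal step: projected gradient descent on the q-reweighted
        # (nonconvex) squared loss, keeping ||w|| <= W.
        grad = 2.0 * X.T @ (q * (relu(z) - y) * relu_grad(z))
        w -= eta * grad
        norm = np.linalg.norm(w)
        if norm > W:
            w *= W / norm
    return w
\end{verbatim}

One feature worth noting in this toy version: because the $\chi^2$ penalty is quadratic in $q$, the inner maximization is a concave quadratic over the simplex and the dual update is an exact best response computed in closed form via a single Euclidean projection, rather than a gradient-ascent step.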