We study the problem of learning hierarchical polynomials over the standard Gaussian distribution with three-layer neural networks. Specifically, we consider target functions of the form $h = g \circ p$, where $p : \mathbb{R}^d \rightarrow \mathbb{R}$ is a degree $k$ polynomial and $g: \mathbb{R} \rightarrow \mathbb{R}$ is a degree $q$ polynomial. This function class generalizes the single-index model, which corresponds to $k=1$, and is a natural class of functions with an underlying hierarchical structure. Our main result shows that for a large subclass of degree $k$ polynomials $p$, a three-layer neural network trained via layerwise gradient descent on the square loss learns the target $h$ up to vanishing test error with $\widetilde{\mathcal{O}}(d^k)$ samples and in polynomial time. This is a strict improvement over kernel methods, which require $\widetilde \Theta(d^{kq})$ samples, and over existing guarantees for two-layer networks, which require the target function to be low-rank. Our result also generalizes prior works on three-layer neural networks, which were restricted to the case where $p$ is a quadratic. When $p$ is a quadratic, we achieve the information-theoretically optimal sample complexity $\widetilde{\mathcal{O}}(d^2)$, an improvement over prior work~\citep{nichani2023provable}, which requires a sample size of $\widetilde\Theta(d^4)$. Our proof shows that during the initial stage of training, the network performs feature learning to recover the feature $p$ with $\widetilde{\mathcal{O}}(d^k)$ samples. This work demonstrates the ability of three-layer neural networks to learn complex features and, as a result, to learn a broad class of hierarchical functions.
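To make the function class concrete, the following is a minimal Python sketch of generating data for a target $h = g \circ p$ with inputs $x \sim \mathcal{N}(0, I_d)$. The specific choices here are hypothetical illustrations, not the subclass of polynomials analyzed in this work: $p$ is a normalized quadratic (so $k=2$) with mean $0$ and variance $1$ under the Gaussian, and $g(z) = z^3 - 3z$ is the degree-3 probabilists' Hermite polynomial (so $q=3$). The helper `sample_scaling` is also hypothetical and only compares the orders $d^k$ and $d^{kq}$, with constants and logarithmic factors dropped.

```python
import math
import random


def sample_gaussian(d, rng):
    """Draw one input x ~ N(0, I_d)."""
    return [rng.gauss(0.0, 1.0) for _ in range(d)]


def p_quadratic(x):
    """Hypothetical degree-2 feature p(x) = sum_i (x_i^2 - 1) / sqrt(2d).

    Under N(0, I_d) each term has mean 0 and variance 2, so p has
    mean 0 and variance 1.
    """
    d = len(x)
    return sum(xi * xi - 1.0 for xi in x) / math.sqrt(2 * d)


def g_cubic(z):
    """Hypothetical degree-3 link g(z) = z^3 - 3z (Hermite He_3)."""
    return z ** 3 - 3 * z


def h(x):
    """Hierarchical target h = g o p."""
    return g_cubic(p_quadratic(x))


def make_dataset(n, d, seed=0):
    """Return n Gaussian inputs in dimension d and their labels h(x)."""
    rng = random.Random(seed)
    xs = [sample_gaussian(d, rng) for _ in range(n)]
    ys = [h(x) for x in xs]
    return xs, ys


def sample_scaling(d, k, q):
    """Orders of magnitude d^k (this work) vs d^(kq) (kernel methods)."""
    return d ** k, d ** (k * q)
```

For example, with $d=100$, $k=2$ and $q=3$, `sample_scaling(100, 2, 3)` returns `(10000, 1000000000000)`, showing how large the gap between the two rates can be even before the hidden factors are taken into account.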