This work focuses on the gradient flow dynamics of a neural network model that uses correlation loss to approximate a multi-index function on high-dimensional standard Gaussian data. Specifically, the multi-index function we consider is a sum of neurons $f^*(x) \!=\! \sum_{j=1}^k \! \sigma^*(v_j^T x)$ where $v_1, \dots, v_k$ are unit vectors, and $\sigma^*$ lacks the first and second Hermite polynomials in its Hermite expansion. It is known that, for the single-index case ($k\!=\!1$), overcoming the search phase requires polynomial time complexity. We first generalize this result to multi-index functions characterized by vectors in arbitrary directions. After the search phase, it is not clear whether the network neurons converge to the index vectors, or get stuck at a sub-optimal solution. When the index vectors are orthogonal, we give a complete characterization of the fixed points and prove that neurons converge to the nearest index vectors. Therefore, using $n \! \asymp \! k \log k$ neurons ensures finding the full set of index vectors with gradient flow with high probability over random initialization. When $ v_i^T v_j \!=\! \beta \! \geq \! 0$ for all $i \neq j$, we prove the existence of a sharp threshold $\beta_c \!=\! c/(c+k)$ at which the fixed point that computes the average of the index vectors transitions from a saddle point to a minimum. Numerical simulations show that using a correlation loss and a mild overparameterization suffices to learn all of the index vectors when they are nearly orthogonal, however, the correlation loss fails when the dot product between the index vectors exceeds a certain threshold.
翻译:本研究聚焦于使用相关性损失在高维标准高斯数据上逼近多索引函数的神经网络模型的梯度流动力学。具体而言,我们考虑的多索引函数为神经元之和 $f^*(x) \!=\! \sum_{j=1}^k \! \sigma^*(v_j^T x)$,其中 $v_1, \dots, v_k$ 为单位向量,且 $\sigma^*$ 在其埃尔米特展开式中缺少前两项埃尔米特多项式。已知在单索引情形($k\!=\!1$)下,克服搜索阶段需要多项式时间复杂度。我们首先将该结果推广至由任意方向向量表征的多索引函数。在搜索阶段之后,网络神经元是否收敛至索引向量,或陷入次优解尚不明确。当索引向量正交时,我们完整刻画了不动点特性并证明神经元会收敛至最近的索引向量。因此,使用 $n \! \asymp \! k \log k$ 个神经元能确保通过梯度流以高概率(在随机初始化条件下)找到完整的索引向量集合。当所有 $i \neq j$ 满足 $ v_i^T v_j \!=\! \beta \! \geq \! 0$ 时,我们证明存在一个尖锐阈值 $\beta_c \!=\! c/(c+k)$,当超过该阈值时,计算索引向量平均值的固定点会从鞍点转变为极小值点。数值模拟表明,当索引向量接近正交时,使用相关性损失与适度过参数化足以学习所有索引向量;然而当索引向量间的点积超过特定阈值时,相关性损失将失效。