Pruning schemes are widely used in practice to reduce the complexity of trained models with a massive number of parameters. Several empirical studies have shown that a pruned model generalizes well to new samples once it is fine-tuned with some gradient-based updates. Although this pipeline, which we refer to as pruning + fine-tuning, has been extremely successful in lowering the complexity of trained models, very little is known about the theory behind its success. In this paper, we address this gap by investigating the pruning + fine-tuning framework on the overparameterized matrix sensing problem, with ground truth $U_\star \in \mathbb{R}^{d \times r}$ and overparameterized model $U \in \mathbb{R}^{d \times k}$, where $k \gg r$. We study the approximate local minima of the mean square error augmented with a smooth version of the group Lasso regularizer $\sum_{i=1}^k \| U e_i \|_2$. In particular, we prove that pruning all columns whose $\ell_2$ norm falls below a certain explicit threshold yields a solution $U_{\text{prune}}$ that has the minimum number of columns, $r$, yet remains close to the ground truth in training loss. Moreover, in the subsequent fine-tuning phase, gradient descent initialized at $U_{\text{prune}}$ converges at a linear rate to its limit. While our analysis provides insights into the role of regularization in pruning, we also show that running gradient descent without regularization results in models that are not suitable for greedy pruning, i.e., many columns could have an $\ell_2$ norm comparable to that of the maximum. To the best of our knowledge, our results provide the first rigorous insights into why greedy pruning + fine-tuning leads to smaller models that also generalize well.
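To make the pipeline concrete, the following is a minimal NumPy sketch of greedy column pruning followed by fine-tuning on a synthetic matrix sensing instance. It is an illustration under stated assumptions, not the construction analyzed in the paper: the function names, the Gaussian sensing matrices, the threshold `0.1`, the step size, and the particular smoothing $\sqrt{\|U e_i\|_2^2 + \varepsilon}$ of the group Lasso term are all hypothetical choices, and the regularized training phase is replaced by a hand-built stand-in with $r$ large columns and $k - r$ tiny ones.

```python
import numpy as np


def column_norms(U):
    """l2 norm of each column U e_i."""
    return np.linalg.norm(U, axis=0)


def smooth_group_lasso(U, eps=1e-8):
    """One possible smoothing of sum_i ||U e_i||_2 (assumed form, not the paper's)."""
    return np.sum(np.sqrt(column_norms(U) ** 2 + eps))


def sensing_loss(U, A, y):
    """Mean square error (1/n) * sum_j (<A_j, U U^T> - y_j)^2."""
    residual = np.einsum("nij,ij->n", A, U @ U.T) - y
    return np.mean(residual ** 2)


def sensing_grad(U, A, y):
    """Gradient of sensing_loss with respect to U."""
    residual = np.einsum("nij,ij->n", A, U @ U.T) - y
    # d/dU <A_j, U U^T> = (A_j + A_j^T) U, so average residual-weighted A_j first.
    S = np.einsum("n,nij->ij", residual, A) / len(y)
    return 2.0 * (S + S.T) @ U


def greedy_prune(U, threshold):
    """Drop every column whose l2 norm is below the threshold."""
    return U[:, column_norms(U) >= threshold]


def fine_tune(U, A, y, lr=0.01, steps=200):
    """Plain gradient descent on the unregularized loss, started at U."""
    U = U.copy()
    for _ in range(steps):
        U -= lr * sensing_grad(U, A, y)
    return U


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d, r, k, n = 5, 2, 6, 200
    # Ground truth with orthonormal columns and Gaussian sensing matrices (assumptions).
    U_star, _ = np.linalg.qr(rng.standard_normal((d, r)))
    A = rng.standard_normal((n, d, d))
    y = np.einsum("nij,ij->n", A, U_star @ U_star.T)
    # Stand-in for a regularized solution: r large columns, k - r tiny ones.
    U = np.hstack([U_star + 0.05 * rng.standard_normal((d, r)),
                   1e-3 * rng.standard_normal((d, k - r))])
    U_prune = greedy_prune(U, threshold=0.1)
    U_ft = fine_tune(U_prune, A, y)
    print(U_prune.shape, sensing_loss(U_prune, A, y), sensing_loss(U_ft, A, y))
```

The sketch separates the two phases on purpose: `greedy_prune` only inspects column norms, and `fine_tune` runs on the unregularized loss from the pruned point. It does not reproduce the paper's threshold or its convergence guarantee.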