Based on the weight-sharing mechanism, one-shot NAS methods train a supernet and then inherit its pre-trained weights to evaluate sub-models, greatly reducing the search cost. However, several works have pointed out that the shared weights suffer from differing gradient descent directions during training. We further find that large gradient variance occurs during supernet training, which degrades the ranking consistency of the supernet. To mitigate this issue, we propose to explicitly minimize the gradient variance of supernet training by jointly optimizing the sampling distributions of PAth and DAta (PA&DA). We theoretically derive the relationship between the gradient variance and the sampling distributions, and show that the optimal sampling probability is proportional to the normalized gradient norm of the path and of the training data. Hence, we use the normalized gradient norm as the importance indicator for paths and training data, and adopt an importance sampling strategy for supernet training. Our method incurs only negligible computational cost for optimizing the sampling distributions of path and data, yet achieves lower gradient variance during supernet training and better generalization performance for the supernet, resulting in a more consistent NAS. We conduct comprehensive comparisons with other improved approaches in various search spaces. Results show that our method surpasses the others, with more reliable ranking performance and higher accuracy of the searched architectures, demonstrating its effectiveness. Code is available at https://github.com/ShunLu91/PA-DA.
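The sampling rule stated in the abstract (probability proportional to the normalized gradient norm) can be illustrated with a minimal sketch. This is not the authors' implementation: the function names `sampling_probs` and `importance_sample`, the uniform fallback when all norms are zero, and the reweighting factor 1/(N p_i) (the standard importance-sampling correction that keeps a reweighted average unbiased) are assumptions made here for illustration. How and when the gradient norms are measured and refreshed during training is not described in the abstract and is left out.

```python
import random


def sampling_probs(grad_norms, eps=1e-12):
    """Turn non-negative gradient norms into a sampling distribution.

    Hypothetical helper: p_i = g_i / sum_j g_j, with a uniform fallback
    (an assumption) when every norm is numerically zero.
    """
    if any(g < 0 for g in grad_norms):
        raise ValueError("gradient norms must be non-negative")
    n = len(grad_norms)
    total = sum(grad_norms)
    if total <= eps:
        return [1.0 / n] * n
    return [g / total for g in grad_norms]


def importance_sample(probs, k, rng):
    """Draw k indices with replacement according to probs.

    Returns (index, weight) pairs. The weight 1 / (N * p_i) is the usual
    importance-sampling correction (assumed here, not taken from the paper).
    Indices with p_i == 0 are never drawn, so the division is safe.
    """
    n = len(probs)
    indices = rng.choices(range(n), weights=probs, k=k)
    return [(i, 1.0 / (n * probs[i])) for i in indices]


# Toy gradient norms for three candidate paths (made-up numbers).
path_norms = [0.5, 1.5, 2.0]
path_probs = sampling_probs(path_norms)
rng = random.Random(0)
batch = importance_sample(path_probs, k=4, rng=rng)
```

The same two helpers could be applied to per-example gradient norms to build the data sampling distribution; in this sketch paths and data are treated identically, which is a simplification.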