Understanding and modelling the performance of neural architectures is key to Neural Architecture Search (NAS). Performance predictors are widely used in low-cost NAS and achieve high rank correlations between predicted and ground-truth performance on several NAS benchmarks. However, existing predictors are often built on network encodings specific to a predefined search space and therefore do not generalize to other search spaces or to new architecture families. In this paper, we propose a general-purpose neural predictor for NAS that transfers across search spaces by representing any candidate Convolutional Neural Network (CNN) as a Computation Graph (CG) of primitive operators. We further combine this CG representation with Contrastive Learning (CL) in a graph representation learning procedure that leverages the structural information of unlabeled architectures from multiple families to train CG embeddings for our performance predictor. Experimental results on NAS-Bench-101, 201 and 301 demonstrate the efficacy of our scheme: we achieve a strong positive Spearman Rank Correlation Coefficient (SRCC) on every search space, outperforming several Zero-Cost Proxies, including Synflow and Jacov, which are also generalizable across search spaces. Moreover, when our general-purpose predictor is used in an evolutionary neural architecture search algorithm, we find high-performance architectures on NAS-Bench-101 as well as a MobileNetV3 architecture that attains 79.2% top-1 accuracy on ImageNet.
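Predictor quality above is reported as the SRCC between predicted and ground-truth performance. The following is a minimal sketch, not the authors' evaluation code. It computes SRCC as the Pearson correlation of the two rank vectors, with tied values assigned their average rank. The helper names `average_ranks` and `spearman_srcc` and the sample scores are hypothetical illustrations, not data from the paper.

```python
from typing import List, Sequence


def average_ranks(values: Sequence[float]) -> List[float]:
    """Return 1-based ranks; tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values tied with the one at sorted position i.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1  # sorted positions i..j map to ranks i+1..j+1
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks


def spearman_srcc(predicted: Sequence[float], ground_truth: Sequence[float]) -> float:
    """Hypothetical helper: SRCC = Pearson correlation of the two rank vectors."""
    if len(predicted) != len(ground_truth) or len(predicted) < 2:
        raise ValueError("need two equal-length sequences with at least 2 items")
    rp = average_ranks(predicted)
    rg = average_ranks(ground_truth)
    n = len(rp)
    mean_p = sum(rp) / n
    mean_g = sum(rg) / n
    cov = sum((a - mean_p) * (b - mean_g) for a, b in zip(rp, rg))
    var_p = sum((a - mean_p) ** 2 for a in rp)
    var_g = sum((b - mean_g) ** 2 for b in rg)
    if var_p == 0 or var_g == 0:
        raise ValueError("SRCC is undefined when one input is constant")
    return cov / (var_p * var_g) ** 0.5


if __name__ == "__main__":
    # Hypothetical predictor scores vs. hypothetical ground-truth accuracies (%).
    predicted = [0.1, 0.4, 0.35, 0.8]
    accuracy = [91.2, 93.5, 93.0, 94.1]
    print(spearman_srcc(predicted, accuracy))  # identical orderings -> 1.0
```

Only the ordering of architectures matters here. A predictor whose raw scores are on a completely different scale from accuracy still attains an SRCC of 1.0 if it ranks every candidate correctly, and ranking is what a search algorithm needs. In practice, `scipy.stats.spearmanr` computes the same tie-corrected statistic.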