Distillation-aware Neural Architecture Search (DaNAS) aims to find an optimal student architecture that achieves the best performance and/or efficiency when distilling knowledge from a given teacher model. Previous DaNAS methods have mostly searched for architectures for a fixed dataset and teacher. These searches do not generalize well to a new task consisting of an unseen dataset and an unseen teacher, so a costly search must be repeated for every new combination of dataset and teacher. For standard NAS tasks without knowledge distillation (KD), computationally efficient meta-learning-based NAS methods have been proposed; they learn a generalizable search process over multiple tasks (datasets) and transfer the knowledge obtained from those tasks to a new task. However, because they assume training from scratch without KD from a teacher, they may not be ideal for DaNAS scenarios. To eliminate both the excessive computational cost of DaNAS methods and the sub-optimality of rapid NAS methods, we propose a distillation-aware meta accuracy prediction model, DaSS (Distillation-aware Student Search). DaSS predicts a given architecture's final performance on a dataset when it is trained with KD from a given teacher, without actually training it on the target task. Experimental results demonstrate that our meta-prediction model successfully generalizes to multiple unseen datasets for DaNAS tasks, substantially outperforming existing meta-NAS methods and rapid NAS baselines. Code is available at https://github.com/CownowAn/DaSS
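The key idea above is that an accuracy predictor conditioned on the dataset and the teacher replaces actual KD training inside the search loop. The sketch below illustrates only that usage pattern, under stated assumptions: `Arch`, `search_student`, and `toy_predictor` are hypothetical names, the architecture encoding (depth, width, parameter count) and the parameter budget are illustrative, and `toy_predictor` is a hand-written stand-in, not the learned DaSS model.

```python
from dataclasses import dataclass
from typing import Callable, Sequence, Tuple


@dataclass(frozen=True)
class Arch:
    """Hypothetical, simplified student architecture encoding."""
    depth: int
    width: int
    params_m: float  # parameter count in millions


# Assumed predictor signature: (architecture, dataset embedding, teacher embedding)
# -> predicted accuracy after KD training. No training happens when it is called.
Predictor = Callable[[Arch, Sequence[float], Sequence[float]], float]


def search_student(
    candidates: Sequence[Arch],
    predictor: Predictor,
    dataset_emb: Sequence[float],
    teacher_emb: Sequence[float],
    max_params_m: float,
) -> Tuple[Arch, float]:
    """Return the feasible candidate with the highest predicted KD accuracy."""
    # Efficiency constraint: keep only students within the parameter budget.
    feasible = [a for a in candidates if a.params_m <= max_params_m]
    if not feasible:
        raise ValueError("no candidate satisfies the parameter budget")
    # Rank candidates by predicted accuracy instead of distilling each one.
    scored = [(a, predictor(a, dataset_emb, teacher_emb)) for a in feasible]
    return max(scored, key=lambda pair: pair[1])


def toy_predictor(arch: Arch, dataset_emb: Sequence[float],
                  teacher_emb: Sequence[float]) -> float:
    """Hand-written stand-in for a learned meta-predictor (illustration only)."""
    return (0.5 + 0.01 * arch.depth + 0.001 * arch.width
            + 0.01 * sum(dataset_emb)
            - 0.002 * arch.params_m * teacher_emb[0])
```

For example, with candidates of 1M, 4M, and 12M parameters and a 5M budget, `search_student` scores only the two feasible students and returns the better-scoring one. In a real system the predictor would be trained across many (dataset, teacher) tasks so that it transfers to unseen ones; the ranking loop itself stays this cheap.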