Model distillation has become a popular method for producing interpretable machine learning models. It uses an interpretable "student" model to mimic the predictions made by a black-box "teacher" model. However, if the student model is sensitive to variability in the data sets used for training, even when the teacher is held fixed, the corresponding interpretation is unreliable. Existing strategies stabilize model distillation by checking whether a large enough corpus of pseudo-data has been generated to reliably reproduce the student model, but such methods have so far been developed only for specific student models. In this paper, we develop a generic approach for stable model distillation based on the central limit theorem for the average loss. We start with a collection of candidate student models and search for candidates that reasonably agree with the teacher. We then construct a multiple testing framework to select a corpus size such that the same student model would be selected consistently across different pseudo-samples. We demonstrate our approach on three commonly used intelligible models: decision trees, falling rule lists and symbolic regression. Finally, we conduct simulation experiments on the Mammographic Mass and Breast Cancer datasets and illustrate the testing procedure through a theoretical analysis based on a Markov process. The code is publicly available at https://github.com/yunzhe-zhou/GenericDistillation.
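The core idea can be sketched as follows. By the central limit theorem, the average loss difference between two candidate students on a pseudo-sample is approximately normal. This makes it possible to test whether the best candidate is reliably better than every competitor, and to estimate how large a corpus would be needed for that test to succeed. The sketch below makes several assumptions. It takes per-example losses of K candidate students on one pseudo-sample as given. It uses paired one-sided z-tests with a Bonferroni correction and the textbook plug-in sample-size formula as stand-ins. The helper names (`best_and_differences`, `paired_z_tests`, `required_corpus_size`) are hypothetical. This is not the authors' procedure, which is in the linked repository.

```python
import math
import random
from statistics import NormalDist, fmean, stdev


def best_and_differences(losses):
    """Return the index of the candidate with the lowest mean loss and, for every
    other candidate j, the per-example loss differences loss_j - loss_best."""
    k = len(losses[0])
    means = [fmean(row[j] for row in losses) for j in range(k)]
    best = min(range(k), key=means.__getitem__)
    diffs = {j: [row[j] - row[best] for row in losses] for j in range(k) if j != best}
    return best, diffs


def paired_z_tests(losses, alpha=0.05):
    """One-sided CLT (z) tests that the best candidate's mean loss is strictly lower
    than each competitor's, Bonferroni-corrected over the k - 1 comparisons."""
    n = len(losses)
    best, diffs = best_and_differences(losses)
    z_crit = NormalDist().inv_cdf(1 - alpha / len(diffs))
    results = []
    for j, d in diffs.items():
        mean_d, sd_d = fmean(d), stdev(d)
        if sd_d == 0:
            z = math.inf if mean_d > 0 else 0.0
        else:
            z = mean_d / (sd_d / math.sqrt(n))
        results.append((j, z, z > z_crit))  # True: best is significantly better than j
    return best, results


def required_corpus_size(losses, alpha=0.05, power=0.8):
    """Plug-in estimate of the corpus size at which every comparison reaches the
    requested power: n = ((z_{1-alpha/(k-1)} + z_power) * sd / delta)^2."""
    best, diffs = best_and_differences(losses)
    z_a = NormalDist().inv_cdf(1 - alpha / len(diffs))
    z_b = NormalDist().inv_cdf(power)
    needed = 0
    for d in diffs.values():
        delta, sd = fmean(d), stdev(d)
        if delta <= 0:
            return math.inf  # tie in mean loss: no finite corpus separates the two
        needed = max(needed, math.ceil(((z_a + z_b) * sd / delta) ** 2))
    return needed


# Synthetic per-example losses for three hypothetical candidate students,
# evaluated on one pseudo-sample of size 2000 labelled by a fixed teacher.
rng = random.Random(0)
true_means = [0.30, 0.32, 0.40]
losses = [[m + rng.gauss(0.0, 0.1) for m in true_means] for _ in range(2000)]

best, results = paired_z_tests(losses)
n_needed = required_corpus_size(losses)
print(best, [(j, round(z, 1), rej) for j, z, rej in results], n_needed)
```

If any test fails to reject, the selected student may change from one pseudo-sample to the next. In that case, `required_corpus_size` suggests how many pseudo-examples to generate before re-testing. Bonferroni is used here only for simplicity. The multiple testing framework in the paper may control error differently.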