Machine unlearning -- efficiently removing the effect of a small "forget set" of training data on a pre-trained machine learning model -- has recently attracted significant research interest. Despite this interest, recent work shows that existing machine unlearning techniques do not hold up to thorough evaluation in non-convex settings. In this work, we introduce a new machine unlearning technique that exhibits strong empirical performance even in such challenging settings. Our starting point is the perspective that the goal of unlearning is to produce a model whose outputs are statistically indistinguishable from those of a model re-trained on all but the forget set. This perspective naturally suggests a reduction from the unlearning problem to that of data attribution, where the goal is to predict the effect of changing the training set on a model's outputs. Thus motivated, we propose the following meta-algorithm, which we call Datamodel Matching (DMM): given a trained model, we (a) use data attribution to predict the output of the model if it were re-trained on all but the forget set points; then (b) fine-tune the pre-trained model to match these predicted outputs. In a simple convex setting, we show how this approach provably outperforms a variety of iterative unlearning algorithms. Empirically, we use a combination of existing evaluations and a new metric based on the KL divergence to show that even in non-convex settings, DMM achieves strong unlearning performance relative to existing algorithms. An added benefit of DMM is that it is a meta-algorithm, in the sense that future advances in data attribution translate directly into better unlearning algorithms, pointing to a clear direction for future progress in unlearning.
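The two-step DMM meta-algorithm described above can be sketched on a toy problem. The sketch below is illustrative only and not the paper's implementation: it assumes a simple linear datamodel (example outputs modeled as a weighted sum over retained training points) for step (a), and uses gradient descent on a squared-error matching loss as a stand-in for the fine-tuning in step (b); all function and variable names (`predict_counterfactual_outputs`, `finetune_to_match`, `W`, `b`) are invented for this example.

```python
import numpy as np

def predict_counterfactual_outputs(W, b, retain_mask):
    """Step (a): a linear datamodel predicts each evaluation example's
    output as a weighted sum over the *retained* training points."""
    return W @ retain_mask + b

def finetune_to_match(model_outputs, target_outputs, lr=0.5, steps=200):
    """Step (b): drive the pre-trained model's outputs toward the
    datamodel-predicted re-training outputs. Here this is plain gradient
    descent on 0.5 * ||out - target||^2, a stand-in for fine-tuning."""
    out = model_outputs.copy()
    for _ in range(steps):
        out -= lr * (out - target_outputs)  # gradient of the matching loss
    return out

rng = np.random.default_rng(0)
n_train, n_eval = 10, 4
W = rng.normal(size=(n_eval, n_train)) * 0.1  # datamodel weights (illustrative)
b = rng.normal(size=n_eval)                   # datamodel bias terms

full_mask = np.ones(n_train)
retain_mask = full_mask.copy()
retain_mask[:3] = 0.0  # the first 3 training points form the "forget set"

# Outputs of the model trained on the full dataset.
pretrained_out = W @ full_mask + b
# Step (a): predicted outputs had we re-trained without the forget set.
target = predict_counterfactual_outputs(W, b, retain_mask)
# Step (b): fine-tune the pre-trained outputs to match the prediction.
unlearned_out = finetune_to_match(pretrained_out, target)

print(np.allclose(unlearned_out, target, atol=1e-3))
```

In this convex toy setting the matching step converges exactly to the predicted counterfactual outputs; the paper's point is that the same two-step recipe remains effective in non-convex settings, where the quality of the result is bounded by the quality of the data attribution method used in step (a).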