Dataset distillation (DD) has emerged as a widely adopted technique for crafting a synthetic dataset that captures the essential information of a training dataset, facilitating the training of accurate neural models. Its applications span various domains, including transfer learning, federated learning, and neural architecture search. The most popular methods for constructing the synthetic data rely on matching the convergence properties of models trained on the synthetic dataset with those of models trained on the original training dataset. However, targeting the training dataset should be regarded as auxiliary: the training set is itself only an approximate substitute for the population distribution, which is the data of ultimate interest. Despite its popularity, the relationship between DD and generalization remains unexplored, particularly across uncommon subgroups. That is, how can we ensure that a model trained on the synthetic dataset performs well on samples from regions of low population density? In this setting, the representativeness and coverage of the synthetic dataset become more salient than the training error it guarantees at inference. Drawing inspiration from distributionally robust optimization, we introduce an algorithm that combines clustering with the minimization of a risk measure on the loss to conduct DD. We provide a theoretical rationale for our approach and demonstrate its effective generalization and robustness across subgroups through numerical experiments. The source code is available at https://github.com/Mming11/RobustDatasetDistillation.
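To make the idea concrete, below is a minimal sketch, not the authors' implementation, of the ingredients named in the abstract: real data are partitioned into clusters (standing in for population subgroups), a synthetic set is optimized against a per-cluster loss, and the per-cluster losses are aggregated with a risk measure (here CVaR) rather than a plain average. The per-cluster loss is a simple feature-matching objective under a random extractor; names such as `cvar`, `alpha`, `Extractor`, and `distill` are hypothetical choices for illustration only.

```python
import torch
import torch.nn as nn

def cvar(losses: torch.Tensor, alpha: float = 0.3) -> torch.Tensor:
    """Conditional Value-at-Risk: mean of the worst alpha-fraction of per-cluster losses."""
    k = max(1, int(round(alpha * losses.numel())))
    return losses.topk(k).values.mean()

class Extractor(nn.Module):
    """Random convolutional feature extractor used only to compare real/synthetic features."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten())
    def forward(self, x):
        return self.net(x)

def distill(real_x, cluster_ids, ipc=10, steps=200, alpha=0.3, lr=0.1):
    """Optimize `ipc` synthetic images per cluster so that the CVaR over
    per-cluster feature-matching losses is small (a DRO-flavored objective)."""
    clusters = cluster_ids.unique()
    syn = torch.randn(len(clusters) * ipc, *real_x.shape[1:], requires_grad=True)
    opt = torch.optim.SGD([syn], lr=lr)
    f = Extractor()
    for _ in range(steps):
        opt.zero_grad()
        per_cluster = []
        for i, c in enumerate(clusters):
            real_feat = f(real_x[cluster_ids == c]).mean(0)
            syn_feat = f(syn[i * ipc:(i + 1) * ipc]).mean(0)
            per_cluster.append(((real_feat - syn_feat) ** 2).sum())
        # Aggregate subgroup losses with a risk measure instead of their mean,
        # so low-density clusters are not ignored during distillation.
        risk = cvar(torch.stack(per_cluster), alpha)
        risk.backward()
        opt.step()
    return syn.detach()

# Hypothetical usage: cluster_ids could come from k-means on pretrained features of real_x.
# syn = distill(real_x, cluster_ids)
```

Replacing the mean over clusters with CVaR is what gives the sketch its distributionally robust flavor: the synthetic set is pushed to cover the worst-represented subgroups rather than only the dominant ones.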