Machine-learning models are prone to capturing spurious correlations between non-causal attributes and classes, and counterfactual data augmentation is a promising way to break these spurious associations. However, explicitly generating counterfactual data is challenging and reduces training efficiency. This study therefore proposes an implicit counterfactual data augmentation (ICDA) method that removes spurious correlations and makes stable predictions. Specifically, we first develop a novel sample-wise augmentation strategy that generates semantically and counterfactually meaningful deep features, with a distinct augmentation strength for each sample. Second, we derive an easy-to-compute surrogate loss on the augmented feature set as the number of augmented samples approaches infinity. Third, we propose two concrete schemes, direct quantification and meta-learning, to derive the key parameters of this robust loss. In addition, we explain ICDA from a regularization perspective, and extensive experiments show that our method consistently improves the generalization performance of popular deep networks across multiple typical learning scenarios that require out-of-distribution generalization.
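The abstract does not state the ICDA surrogate loss, so the following is only a minimal sketch of the general idea behind "a surrogate loss as the number of augmented samples approaches infinity". It assumes the Gaussian feature-perturbation bound used by implicit semantic data augmentation (ISDA), which is a swapped-in technique, not the authors' formula. If each deep feature a_i is perturbed as a_i + eps with eps ~ N(0, lambda_i * Sigma_{y_i}), Jensen's inequality gives a closed-form upper bound on the expected cross-entropy: every logit j is shifted by (lambda_i / 2) (w_j - w_y)^T Sigma_{y_i} (w_j - w_y). The per-sample strengths `lambda_i`, the given class covariances, and the function names `implicit_aug_surrogate_loss` and `monte_carlo_ce` are hypothetical; ICDA's counterfactual directions and its two parameter-estimation schemes are not modeled.

```python
import numpy as np


def implicit_aug_surrogate_loss(features, labels, W, b, class_covs, strengths):
    """Closed-form upper bound on the expected cross-entropy when each feature a_i
    is perturbed infinitely often as a_i + eps, eps ~ N(0, lambda_i * Sigma_{y_i}).

    ISDA-style sketch (assumption), not the ICDA loss itself.
    features:   (N, D) deep features
    labels:     (N,) integer class labels
    W, b:       (C, D) weights and (C,) biases of the final linear layer
    class_covs: (C, D, D) per-class feature covariances (assumed given)
    strengths:  (N,) hypothetical per-sample augmentation strengths lambda_i
    """
    logits = features @ W.T + b  # (N, C)
    losses = np.empty(len(labels))
    for i, (y, lam) in enumerate(zip(labels, strengths)):
        dW = W - W[y]  # row j holds w_j - w_y (row y is exactly zero)
        # (w_j - w_y)^T Sigma_y (w_j - w_y) for every class j
        quad = np.einsum("cd,de,ce->c", dW, class_covs[y], dW)
        z = logits[i] + 0.5 * lam * quad  # inflate the non-target logits
        z_max = z.max()  # stable log-sum-exp
        losses[i] = z_max + np.log(np.exp(z - z_max).sum()) - z[y]
    return losses.mean()


def monte_carlo_ce(features, labels, W, b, class_covs, strengths,
                   n_samples=20000, seed=0):
    """Average cross-entropy over explicitly sampled augmented features."""
    rng = np.random.default_rng(seed)
    total = 0.0
    for a, y, lam in zip(features, labels, strengths):
        samples = rng.multivariate_normal(a, lam * class_covs[y], size=n_samples)
        z = samples @ W.T + b  # (S, C)
        z_max = z.max(axis=1, keepdims=True)
        lse = z_max[:, 0] + np.log(np.exp(z - z_max).sum(axis=1))
        total += (lse - z[:, y]).mean()
    return total / len(labels)


# Toy usage with random data (all values below are synthetic).
rng = np.random.default_rng(42)
N, D, C = 8, 5, 3
features = rng.normal(size=(N, D))
labels = rng.integers(0, C, size=N)
W = rng.normal(size=(C, D))
b = rng.normal(size=C)
A = rng.normal(size=(C, D, D))
class_covs = A @ A.transpose(0, 2, 1) / D  # symmetric PSD covariances
strengths = rng.uniform(0.1, 1.0, size=N)

bound = implicit_aug_surrogate_loss(features, labels, W, b, class_covs, strengths)
sampled = monte_carlo_ce(features, labels, W, b, class_covs, strengths)
print(f"surrogate bound: {bound:.4f}, sampled estimate: {sampled:.4f}")
```

The Monte Carlo routine takes the explicit route (drawing thousands of augmented features per sample), which is the cost the abstract argues against; the closed-form bound needs only one adjusted set of logits per sample. With all strengths set to zero, the surrogate reduces to ordinary cross-entropy, and larger strengths can only increase it, because the added quadratic terms are non-negative and the target logit is unchanged.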