Model substructure learning aims to find an invariant network substructure that generalizes better out-of-distribution (OOD) than the original full structure. Existing works usually search for the invariant substructure using modular risk minimization (MRM) with fully exposed out-domain data, which has two drawbacks: 1) unfairness, due to the dependence on full exposure of out-domain data; and 2) sub-optimal OOD generalization, because the pruning is applied uniformly over the whole data distribution without targeting specific features. Based on the idea that in-distribution (ID) data with spurious features may have a lower empirical risk, in this paper we propose a novel Spurious Feature-targeted model Pruning framework, dubbed SFP, to automatically explore invariant substructures without incurring the above drawbacks. Specifically, SFP identifies spurious features within ID instances during training using our theoretically verified task loss, and then attenuates the corresponding feature projections in model space to achieve the so-called spurious feature-targeted pruning. This is typically done by removing network branches that depend strongly on the identified spurious features; SFP can thus push model learning toward invariant features, pull it away from spurious features, and achieve optimal OOD generalization. Moreover, we conduct a detailed theoretical analysis to provide a rationality guarantee and a proof framework for OOD structures via model sparsity, and, for the first time, reveal how a highly biased data distribution affects the model's OOD generalization. Experiments on various OOD datasets show that SFP can significantly outperform both structure-based and non-structure-based OOD generalization SOTAs, with accuracy improvements of up to 4.72% and 23.35%, respectively.
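The two steps described above (flagging ID instances that likely carry spurious features by their low risk, then attenuating the feature channels those instances rely on) can be illustrated with a toy NumPy sketch. This is not SFP's actual task loss or pruning rule, which the abstract does not spell out: the quantile threshold, the soft per-channel scale, and the function names `flag_low_risk` and `channel_attenuation` are all assumptions made purely for illustration.

```python
import numpy as np


def flag_low_risk(losses, quantile=0.5):
    """Flag samples whose loss is at or below a batch quantile.

    Hypothetical rule: low-loss ID samples are treated as likely to be
    dominated by spurious features.
    """
    losses = np.asarray(losses, dtype=float)
    return losses <= np.quantile(losses, quantile)


def channel_attenuation(features, flagged, strength=1.0):
    """Return a per-channel scale in (0, 1].

    Channels that are more active on flagged samples than on the rest
    are scaled down; all other channels keep a scale of 1.
    """
    features = np.abs(np.asarray(features, dtype=float))
    flagged = np.asarray(flagged, dtype=bool)
    # With no contrast between the two groups, leave every channel untouched.
    if flagged.all() or not flagged.any():
        return np.ones(features.shape[1])
    act_flagged = features[flagged].mean(axis=0)
    act_rest = features[~flagged].mean(axis=0)
    excess = np.maximum(act_flagged - act_rest, 0.0)
    return 1.0 / (1.0 + strength * excess)


# Toy batch: channel 1 fires only on the two low-loss samples.
losses = np.array([0.1, 0.2, 1.5, 2.0])
features = np.array([[1.0, 3.0],
                     [1.0, 3.0],
                     [1.0, 0.0],
                     [1.0, 0.0]])
flagged = flag_low_risk(losses)
scale = channel_attenuation(features, flagged)
print(flagged, scale)  # channel 0 keeps scale 1.0, channel 1 drops to 0.25
```

A soft scale is used here instead of hard branch removal only to keep the example short; thresholding `scale` would give a binary pruning mask.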