Machine learning algorithms have shown remarkable performance in diverse applications. However, it is still challenging to guarantee performance in distribution shifts when distributions of training and test datasets are different. There have been several approaches to improve the performance in distribution shift cases by learning invariant features across groups or domains. However, we observe that the previous works only learn invariant features partially. While the prior works focus on the limited invariant features, we first raise the importance of the sufficient invariant features. Since only training sets are given empirically, the learned partial invariant features from training sets might not be present in the test sets under distribution shift. Therefore, the performance improvement on distribution shifts might be limited. In this paper, we argue that learning sufficient invariant features from the training set is crucial for the distribution shift case. Concretely, we newly observe the connection between a) sufficient invariant features and b) flatness differences between groups or domains. Moreover, we propose a new algorithm, Adaptive Sharpness-aware Group Distributionally Robust Optimization (ASGDRO), to learn sufficient invariant features across domains or groups. ASGDRO learns sufficient invariant features by seeking common flat minima across all groups or domains. Therefore, ASGDRO improves the performance on diverse distribution shift cases. Besides, we provide a new simple dataset, Heterogeneous-CMNIST, to diagnose whether the various algorithms learn sufficient invariant features.
翻译:机器学习算法已在多种应用中展现出卓越性能。然而,当训练集与测试集分布存在差异时,如何保证分布漂移场景下的性能仍是一大挑战。现有研究通过跨群体或跨域学习不变特征来提升分布漂移下的表现。但本文发现,先前工作仅能学习到部分不变特征。鉴于现有研究聚焦于有限的不变特征,我们首次提出充分不变特征的重要性。由于实际中仅能获取训练集,训练集学得的局部不变特征在分布漂移下可能不适用于测试集,因而对分布漂移的性能提升存在局限。本文指出,从训练集中学习充分不变特征对应对分布漂移至关重要。具体而言,我们首次观察到a)充分不变特征与b)群体或域间平坦度差异之间的关联。据此提出新算法——自适应锐度感知群体分布鲁棒优化(ASGDRO),用于学习跨领域或群体的充分不变特征。ASGDRO通过寻找所有群体或域的公共平坦极小值来学习充分不变特征,从而在多种分布漂移场景中提升性能。此外,我们构建了新型简单数据集Heterogeneous-CMNIST,用于诊断各类算法是否学习到充分不变特征。