This paper addresses out-of-distribution (OOD) generalization in graph machine learning, a rapidly advancing field that still struggles with the discrepancy between source and target data distributions. Traditional graph learning algorithms assume that training and test data are identically distributed, and their performance degrades in real-world scenarios where this assumption fails. A principal cause is the inherent simplicity bias of neural networks trained with Stochastic Gradient Descent (SGD), which prefer simpler features over more complex ones that are equally or more predictive. This bias leads to a reliance on spurious correlations and harms OOD performance in tasks such as image recognition, natural language understanding, and graph classification. Current methods, including subgraph-mixup and information bottleneck approaches, have achieved partial success but struggle to overcome simplicity bias and often reinforce spurious correlations. To tackle this, we propose DIVE, which trains a collection of models to cover all label-predictive subgraphs by encouraging the models to diverge on their subgraph masks. This circumvents the limitation of a single model that focuses only on the subgraph corresponding to simple structural patterns. Specifically, we employ a regularizer that penalizes overlap between the subgraphs extracted by different models, thereby encouraging each model to concentrate on a distinct structural pattern. Models are selected for robust OOD performance according to validation accuracy. Evaluated on four datasets from the GOOD benchmark and one dataset from the DrugOOD benchmark, our approach significantly improves over existing methods, effectively addressing simplicity bias and enhancing generalization in graph machine learning.
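To make the overlap regularizer and the validation-based selection step concrete, the following is a minimal sketch and not the paper's implementation. It assumes that each model outputs a soft edge mask with values in [0, 1] and that overlap is measured as the mean pairwise inner product of these masks; the abstract does not specify the exact form of the penalty. The names `mask_overlap_penalty` and `select_by_validation` are hypothetical.

```python
import numpy as np


def mask_overlap_penalty(masks):
    """Mean pairwise overlap between soft subgraph masks (assumed form).

    masks: array-like of shape (num_models, num_edges), entries in [0, 1].
    Returns 0.0 when the masks are disjoint and grows as masks coincide.
    """
    masks = np.asarray(masks, dtype=float)
    num_models, num_edges = masks.shape
    if num_models < 2:
        return 0.0
    # Gram matrix: entry (i, j) is the sum over edges of mask_i * mask_j.
    gram = masks @ masks.T
    # Sum the strict upper triangle so each unordered model pair counts once.
    pair_sum = np.triu(gram, k=1).sum()
    num_pairs = num_models * (num_models - 1) / 2
    # Normalize by pairs and edges so the value lies in [0, 1].
    return float(pair_sum / (num_pairs * num_edges))


def select_by_validation(val_accuracies):
    """Return the index of the model with the highest validation accuracy."""
    return int(np.argmax(np.asarray(val_accuracies, dtype=float)))


# Two models attending to disjoint edges incur no penalty;
# two models attending to the same edges are penalized.
disjoint = [[1, 1, 0, 0], [0, 0, 1, 1]]
identical = [[1, 1, 0, 0], [1, 1, 0, 0]]
print(mask_overlap_penalty(disjoint))
print(mask_overlap_penalty(identical))
print(select_by_validation([0.71, 0.78, 0.74]))
```

In a training loop, a term like this would be added to the sum of the per-model classification losses with a weighting coefficient, so that gradient descent pushes the models toward different subgraphs while each one remains label-predictive.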