Causal discovery from observational data remains challenging due to the need to recover directed structure and latent confounding without interventions. We propose FoundCause, an amortized causal discovery model trained entirely on synthetic data that maps datasets directly to causal graphs in a single forward pass. By learning from large collections of simulated structural causal models, FoundCause captures transferable statistical patterns that generalize beyond individual datasets. The architecture incorporates several key inductive biases for causal discovery. It uses a permutation-invariant transformer encoder with alternating attention over samples and variables to jointly model cross-variable dependence and per-variable distributions. Pairwise statistical features derived from classical asymmetry measures are injected through statistics-conditioned attention, guiding the model toward known causal signals. A factorized decoder separates edge existence from direction, while a triangular refinement module enables reasoning over higher-order causal motifs such as chains and colliders. In addition, a dedicated confounder module based on learnable latent tokens explicitly models hidden common causes, and the model explicitly handles missing data via its masked input representation. To our knowledge, FoundCause is the first amortized causal discovery approach to explicitly model latent confounding. FoundCause outperforms 11 classical non-amortized methods (e.g., PC, GES, NOTEARS-style optimization) and 4 amortized causal discovery methods on 15 real-world datasets, achieving +9.6% improvement in $F_1$, +1.2% in AUROC, and an 18.9% reduction in structural Hamming distance relative to the strongest non-amortized methods, while performing inference in a single forward pass.
翻译:从观测数据中因果发现仍具挑战性,原因在于需在无干预条件下恢复有向结构与潜在混杂因素。我们提出FoundCause——一种完全基于合成数据训练的摊销式因果发现模型,通过单次前向传播即可将数据集直接映射至因果图。通过从大规模模拟结构因果模型中学习,FoundCause捕捉到超越单个数据集的可迁移统计模式。该架构融合了因果发现的若干关键归纳偏置:采用基于交替注意力机制的置换不变Transformer编码器,在样本与变量维度实现跨变量依赖与单变量分布的联合建模;通过统计条件注意力注入源于经典非对称度量的成对统计特征,引导模型捕捉已知因果信号;分解式解码器将边存在性与方向性相分离,而三角精修模块支持对链式、对撞结构等高阶因果模式的推理。此外,基于可学习隐变量标记的专用混杂模块显式建模隐藏共同原因,模型通过掩码输入表示处理缺失数据。据我们所知,FoundCause是首个显式建模隐混杂因素的摊销式因果发现方法。在15个真实数据集上,FoundCause超越11种经典非摊销方法(如PC、GES、NOTEARS式优化)及4种摊销式因果发现方法,相较于最强非摊销方法实现$F_1$值提升+9.6%,AUROC提升+1.2%,结构汉明距离降低18.9%,且推理过程仅需单次前向传播。