Foundation models have reshaped the landscape of machine learning, demonstrating sparks of human-level intelligence across a diverse array of tasks. However, a gap persists in complex tasks such as causal inference, primarily due to challenges associated with intricate reasoning steps and high numerical precision requirements. In this work, we take a first step towards building causally-aware foundation models for complex tasks. We propose a novel, theoretically sound method called Causal Inference with Attention (CInA), which utilizes multiple unlabeled datasets to perform self-supervised causal learning and subsequently enables zero-shot causal inference on unseen tasks with new data. This is grounded in our theoretical result establishing a primal-dual connection between optimal covariate balancing and self-attention, which allows zero-shot causal inference to be carried out through the final layer of a trained transformer-type architecture. We demonstrate empirically that CInA generalizes effectively to out-of-distribution datasets and a variety of real-world datasets, matching or even surpassing traditional per-dataset causal inference methodologies.
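To make the covariate-balancing-as-attention idea concrete, the following is a minimal illustrative sketch, not the paper's actual architecture or training procedure: it treats a single softmax self-attention pass over units (with identity query/key maps and mean pooling, both simplifying assumptions) as producing per-unit balancing weights, which are then plugged into a standard weighted difference-of-means ATE estimator. The function names and the toy data are hypothetical.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention_balancing_weights(X):
    # Illustrative single-head self-attention over units (rows of X).
    # Identity query/key maps and mean pooling are simplifying assumptions;
    # CInA learns these maps via self-supervised training across datasets.
    d_k = X.shape[1]
    scores = X @ X.T / np.sqrt(d_k)   # (n, n) similarity between units
    A = softmax(scores, axis=1)       # attention weights; each row sums to 1
    return A.mean(axis=0)             # per-unit weights, nonnegative, sum to 1

def weighted_ate(Y, T, w):
    # Weighted difference of outcome means between treated and control units.
    w1 = w * T / np.sum(w * T)
    w0 = w * (1 - T) / np.sum(w * (1 - T))
    return np.sum(w1 * Y) - np.sum(w0 * Y)

# Toy usage on synthetic data with a true treatment effect of 2.0.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))                       # covariates
T = rng.integers(0, 2, size=100)                    # binary treatment
Y = X @ rng.normal(size=5) + 2.0 * T + rng.normal(size=100)
w = attention_balancing_weights(X)
print(f"ATE estimate: {weighted_ate(Y, T, w):.3f}")
```

In the paper's setting, the analogue of `attention_balancing_weights` is the final attention layer of a transformer trained across many unlabeled datasets, so the weights for a new dataset are produced in a single forward pass, which is what enables zero-shot causal inference.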