Foundation models have reshaped the landscape of machine learning, demonstrating sparks of human-level intelligence across a diverse array of tasks. Yet a gap persists on complex tasks such as causal inference, primarily because of the intricate reasoning steps and high numerical precision they require. In this work, we take a first step towards building causally-aware foundation models for treatment effect estimation. We propose a novel, theoretically grounded method called Causal Inference with Attention (CInA), which performs self-supervised causal learning on multiple unlabeled datasets and can subsequently carry out zero-shot causal inference on unseen tasks with new data. This capability rests on our theoretical result establishing a primal-dual connection between optimal covariate balancing and self-attention, so that zero-shot causal inference can be read off the final layer of a trained transformer-type architecture. Empirically, we show that CInA generalizes effectively to out-of-distribution datasets and a variety of real-world datasets, matching or even surpassing traditional per-dataset methods. These results provide compelling evidence that our method can serve as a stepping stone for the development of causal foundation models.
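To make the mechanics behind the primal-dual claim concrete, the following is a minimal toy sketch, written for this summary, of how per-sample balancing weights might be read off a self-attention layer and plugged into a weighted average treatment effect (ATE) estimate. It is not the authors' CInA implementation: the synthetic dataset, the random projections `W_q`/`W_k`, and the column-averaged weighting scheme are all illustrative assumptions standing in for the final layer of a trained transformer.

```python
# Hypothetical sketch only: attention weights as covariate-balancing weights
# for a weighted ATE estimate. Not the CInA method itself.
import numpy as np

rng = np.random.default_rng(0)

# Toy observational dataset: covariates X, binary treatment T, outcome Y.
n, d = 200, 5
X = rng.normal(size=(n, d))
propensity = 1 / (1 + np.exp(-X[:, 0]))          # treatment depends on X (confounding)
T = rng.binomial(1, propensity)
Y = 2.0 * T + X @ rng.normal(size=d) + rng.normal(size=n)  # true ATE = 2.0

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# Stand-in for the final layer of a trained transformer: one self-attention
# pass over samples-as-tokens. CInA's theory says a *trained* attention layer
# encodes optimal balancing weights; here the projections are untrained and
# random, purely to illustrate the data flow.
d_k = 8
W_q = rng.normal(size=(d, d_k))
W_k = rng.normal(size=(d, d_k))
A = softmax((X @ W_q) @ (X @ W_k).T / np.sqrt(d_k))  # (n, n) attention matrix

# Aggregate attention mass per sample as a (toy) balancing weight.
w = A.mean(axis=0)
w = w / w.sum()

# Weighted difference of outcomes between treated and control groups.
w_t = w * T / (w * T).sum()
w_c = w * (1 - T) / (w * (1 - T)).sum()
ate_hat = (w_t * Y).sum() - (w_c * Y).sum()
print(f"estimated ATE: {ate_hat:.2f} (ground truth: 2.00)")
```

In the method described above, the attention parameters are instead learned by self-supervised training across many datasets, so that the final-layer weights solve the dual of the covariate balancing problem; applying that layer to a new dataset is what yields zero-shot causal inference.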