Estimating the structure of directed acyclic graphs (DAGs) from observational data remains a significant challenge in machine learning. Most research in this area concentrates on learning a single DAG for the entire population. This paper considers an alternative setting where the graph structure varies across individuals based on available "contextual" features. We tackle this contextual DAG problem via a neural network that maps the contextual features to a DAG, represented as a weighted adjacency matrix. The neural network is equipped with a novel projection layer that ensures the output matrices are sparse and satisfy a recently developed characterization of acyclicity. We devise a scalable computational framework for learning contextual DAGs and provide a convergence guarantee and an analytical gradient for backpropagating through the projection layer. Our experiments suggest that the new approach can recover the true context-specific graph where existing approaches fail.
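The abstract does not say which acyclicity characterization or projection the method uses. The sketch below is only an illustration, built from well-known ingredients plus a hypothetical context-to-matrix map:

- **Acyclicity check.** It uses the trace-exponential function h(W) = tr(exp(W ∘ W)) - d, popularized by NOTEARS, which is zero exactly when the weighted graph W has no directed cycles.
- **Sparsity.** It uses soft-thresholding, the proximal operator of the L1 penalty.

The following are assumptions made for illustration and are not the paper's projection layer:

- the names `contextual_dag`, `params`, and `lam`;
- the linear parameterization;
- the strict-upper-triangular mask, which guarantees acyclicity under a fixed variable order.

```python
import numpy as np


def notears_h(W: np.ndarray, terms: int = 30) -> float:
    """Trace-exponential acyclicity function h(W) = tr(exp(W * W)) - d.

    h(W) == 0 exactly when W has no directed cycles. exp is approximated by a
    truncated power series, which is adequate for small d and modest weights.
    """
    d = W.shape[0]
    A = W * W  # elementwise square: all entries nonnegative
    term = np.eye(d)
    total = np.eye(d)
    for k in range(1, terms + 1):
        term = term @ A / k  # accumulates A^k / k!
        total = total + term
    return float(np.trace(total) - d)


def soft_threshold(W: np.ndarray, lam: float) -> np.ndarray:
    """Proximal operator of lam * ||W||_1: shrinks entries and zeroes small ones."""
    return np.sign(W) * np.maximum(np.abs(W) - lam, 0.0)


def contextual_dag(z: np.ndarray, params: np.ndarray, lam: float = 0.1) -> np.ndarray:
    """Hypothetical context -> DAG map (assumption, not the paper's layer).

    params has shape (d, d, k) and z has shape (k,), so each edge weight is a
    linear function of the context. Soft-thresholding gives sparsity; keeping
    only edges i -> j with i < j guarantees acyclicity by construction.
    """
    W = params @ z  # shape (d, d)
    W = soft_threshold(W, lam)
    return np.triu(W, k=1)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d, k = 4, 3
    params = rng.normal(size=(d, d, k))  # hypothetical learned parameters
    for z in (np.array([1.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])):
        W = contextual_dag(z, params)
        # Edge weights change with z; h(W) is 0.0 for every context
        print(np.count_nonzero(W), notears_h(W))
```

Masking to a fixed order makes acyclicity hold by construction. However, it excludes any graph with an edge from a later variable to an earlier one, so it serves here only as a simple stand-in for a projection onto the set of sparse, acyclic matrices.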