Currently, it is hard to reap the benefits of deep learning for Bayesian methods, which allow the explicit specification of prior knowledge and accurately capture model uncertainty. We present Prior-Data Fitted Networks (PFNs). PFNs leverage large-scale machine learning techniques to approximate a large set of posteriors. The only requirement for PFNs to work is the ability to sample from a prior distribution over supervised learning tasks (or functions). Our method restates the objective of posterior approximation as a supervised classification problem with a set-valued input: it repeatedly draws a task (or function) from the prior, draws a set of data points and their labels from it, masks one of the labels and learns to make probabilistic predictions for it based on the set-valued input of the rest of the data points. Presented with a set of samples from a new supervised learning task as input, PFNs make probabilistic predictions for arbitrary other data points in a single forward propagation, having learned to approximate Bayesian inference. We demonstrate that PFNs can near-perfectly mimic Gaussian processes and also enable efficient Bayesian inference for intractable problems, with over 200-fold speedups in multiple setups compared to current methods. We obtain strong results in very diverse areas such as Gaussian process regression, Bayesian neural networks, classification for small tabular data sets, and few-shot image classification, demonstrating the generality of PFNs. Code and trained PFNs are released at https://github.com/automl/TransformersCanDoBayesianInference.
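The training procedure described in the abstract (draw a task from the prior, draw labeled points from it, mask one label, predict it from the rest) can be illustrated with a minimal sketch. This is not the authors' implementation. The toy prior over noisy lines, the equal-width bucketing of the real-valued label into classes, and every name below (`sample_task`, `mask_one_label`, `bucket_index`, `uniform_model`, `average_loss`) are assumptions made for illustration, and `uniform_model` is only a stand-in for a trained network.

```python
import math
import random


def sample_task(rng, n_points=10, noise_std=0.1):
    """Draw one task from a toy prior (assumed): a random line plus Gaussian noise."""
    slope, intercept = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    xs = [rng.uniform(0.0, 1.0) for _ in range(n_points)]
    ys = [slope * x + intercept + rng.gauss(0.0, noise_std) for x in xs]
    return xs, ys


def mask_one_label(rng, xs, ys):
    """Hold out one (x, y) pair; the remaining pairs form the set-valued input."""
    i = rng.randrange(len(xs))
    context = [(x, y) for j, (x, y) in enumerate(zip(xs, ys)) if j != i]
    return context, xs[i], ys[i]


def bucket_index(y, low, high, n_buckets):
    """Discretize a real label into one of n_buckets equal-width classes (assumed scheme)."""
    k = int((y - low) / (high - low) * n_buckets)
    return min(max(k, 0), n_buckets - 1)  # out-of-range labels go to the edge buckets


def uniform_model(context, query_x, n_buckets):
    """Placeholder for the network: ignores its input and predicts uniformly."""
    return [1.0 / n_buckets] * n_buckets


def average_loss(rng, model, n_tasks=100, n_buckets=20, low=-4.0, high=4.0):
    """Monte Carlo estimate of the cross-entropy on the masked labels."""
    total = 0.0
    for _ in range(n_tasks):
        xs, ys = sample_task(rng)
        context, query_x, query_y = mask_one_label(rng, xs, ys)
        probs = model(context, query_x, n_buckets)
        total -= math.log(probs[bucket_index(query_y, low, high, n_buckets)])
    return total / n_tasks


rng = random.Random(0)
loss = average_loss(rng, uniform_model)
```

With the uniform placeholder, the loss equals `math.log(n_buckets)` regardless of the sampled data, which gives a baseline that any trained model should beat. In an actual PFN, `uniform_model` would be replaced by a network that consumes the context set and the query point, and the loss would be minimized by gradient descent over many sampled tasks.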