Deep learning (DL) models for tabular data problems (e.g., classification and regression) are receiving increasing attention from researchers. However, despite recent efforts, non-DL algorithms based on gradient-boosted decision trees (GBDT) remain a strong go-to solution for these problems. One research direction aimed at improving the standing of tabular DL is the design of so-called retrieval-augmented models. For a target object, such models retrieve other objects (e.g., its nearest neighbors) from the available training data and use their features and labels to make a better prediction. In this work, we present TabR: essentially, a feed-forward network with a custom k-Nearest-Neighbors-like component in the middle. On a set of public benchmarks with datasets of up to several million objects, TabR marks a big step forward for tabular DL. It demonstrates the best average performance among tabular DL models, becomes the new state of the art on several datasets, and even outperforms GBDT models on the recently proposed "GBDT-friendly" benchmark (see Figure 1). The most important findings and technical details behind TabR concern its attention-like mechanism, which is responsible for retrieving the nearest neighbors and extracting a valuable signal from them. In addition to its much higher performance, TabR is simple and significantly more efficient than prior retrieval-based tabular DL models.
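To make the idea of "a feed-forward network with a kNN-like component in the middle" more concrete, the following is a toy, pure-Python sketch of a generic retrieval-augmented regressor. It is not TabR's actual architecture. The single ReLU encoder layer, the negative squared Euclidean similarity, the softmax over the top-k candidates, the label embedding `label_emb`, the residual sum, and the linear `head` are all illustrative assumptions. The function names `encode`, `retrieve`, and `predict` are likewise hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]
Matrix = Sequence[Sequence[float]]


def encode(x: Sequence[float], w_enc: Matrix) -> Vector:
    """Hypothetical encoder: a single linear layer followed by ReLU."""
    return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in w_enc]


def softmax(scores: Sequence[float]) -> Vector:
    """Numerically stable softmax (subtracts the maximum before exponentiating)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def retrieve(query: Vector, keys: Sequence[Vector], k: int) -> List[Tuple[int, float]]:
    """Return (candidate index, attention weight) pairs for the k nearest keys.

    Assumed similarity: negative squared Euclidean distance. The weights are a
    softmax over the top-k similarities, so they are positive and sum to 1.
    """
    sims = [-sum((q - c) ** 2 for q, c in zip(query, key)) for key in keys]
    top = sorted(range(len(keys)), key=lambda i: sims[i], reverse=True)[:k]
    weights = softmax([sims[i] for i in top])
    return list(zip(top, weights))


def predict(
    x: Sequence[float],
    train_x: Sequence[Sequence[float]],
    train_y: Sequence[float],
    w_enc: Matrix,
    label_emb: Sequence[float],
    head: Sequence[float],
    k: int = 3,
) -> float:
    """Encoder -> kNN-like retrieval -> aggregation -> linear head (regression)."""
    query = encode(x, w_enc)
    keys = [encode(xi, w_enc) for xi in train_x]
    neighbors = retrieve(query, keys, k)
    # Context vector: attention-weighted mix of neighbor embeddings (features)
    # and neighbor labels, injected through a hypothetical label embedding.
    context = [0.0] * len(query)
    for i, w in neighbors:
        for d in range(len(context)):
            context[d] += w * (keys[i][d] + train_y[i] * label_emb[d])
    # Residual combination of the target's own embedding and the retrieved context,
    # followed by a hypothetical linear prediction head.
    hidden = [q + c for q, c in zip(query, context)]
    return sum(h * wh for h, wh in zip(hidden, head))
```

For example, with an identity encoder `[[1.0, 0.0], [0.0, 1.0]]`, candidates `[[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]]`, and target `[0.1, 0.0]`, `retrieve(..., k=2)` returns candidates 0 and 1. Candidate 0 is closer and gets a weight of about 0.69. In a trained model, the encoder, label embedding, and head would be learned end to end. In this sketch, they are fixed inputs, so the example only illustrates the data flow.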