Deep learning (DL) models for tabular data problems are receiving increasing attention, while algorithms based on gradient-boosted decision trees (GBDT) remain a strong go-to solution. Following recent trends in other domains, such as natural language processing and computer vision, several retrieval-augmented tabular DL models have recently been proposed. For a given target object, a retrieval-based model retrieves other relevant objects, such as its nearest neighbors, from the available (training) data and uses their features, or even their labels, to make a better prediction. However, we show that existing retrieval-based tabular DL solutions provide only minor benefits, if any, over properly tuned, simple retrieval-free baselines. Thus, it remains unclear whether retrieval is a worthwhile direction for tabular DL. In this work, we give a strong positive answer to this question. We start by incrementally augmenting a simple feed-forward architecture with an attention-like retrieval component similar to those used in many (tabular) retrieval-based models. Then, we highlight several details of the attention mechanism that turn out to have a massive impact on performance on tabular data problems but were not explored in prior work. As a result, we design TabR, a simple retrieval-based tabular DL model which, on a set of public benchmarks, demonstrates the best average performance among tabular DL models, becomes the new state of the art on several datasets, and even outperforms GBDT models on the recently proposed "GBDT-friendly" benchmark (see the first figure).
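To make the retrieve-then-predict idea concrete, the following is a minimal sketch of a generic attention-like retrieval component in Python with NumPy. It is not TabR's architecture: the abstract does not specify TabR's similarity function or how retrieved objects are turned into values, so the choices here (negative squared Euclidean distance as the similarity, a softmax over the top-k neighbors, neighbor labels as the values) and the function name `retrieve_and_aggregate` are illustrative assumptions.

```python
import numpy as np


def retrieve_and_aggregate(x, train_x, train_y, k=3, temperature=1.0):
    """Generic attention-like retrieval for a single target object x.

    Hypothetical illustration only, not TabR: similarity is the negative
    squared Euclidean distance, and the neighbors' labels act as values.
    """
    # Similarity of the target to every training object (higher = closer).
    sims = -np.sum((train_x - x) ** 2, axis=1)
    # Retrieve the k most similar objects; a stable sort makes ties deterministic.
    top = np.argsort(-sims, kind="stable")[:k]
    # Softmax over the retrieved similarities gives the attention weights.
    logits = sims[top] / temperature
    weights = np.exp(logits - logits.max())  # subtract max for numerical stability
    weights /= weights.sum()
    # Prediction: attention-weighted average of the retrieved labels.
    return float(weights @ train_y[top]), top


# Toy training set: three points near the origin and one far-away outlier.
train_x = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
train_y = np.array([0.0, 1.0, 1.0, 10.0])

pred, idx = retrieve_and_aggregate(np.array([0.0, 0.0]), train_x, train_y, k=3)
print(round(pred, 4), sorted(idx.tolist()))  # prints 0.4239 [0, 1, 2]
```

The far-away object (label 10.0) is never retrieved, so it does not influence the prediction. In a trained model, the raw features would typically be replaced by learned representations, and neighbor features could be aggregated alongside (or instead of) labels; the sketch only shows the basic retrieve-then-attend pattern.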