Decision Trees (DTs) constitute one of the major highly non-linear AI models, valued, e.g., for their efficiency on tabular data. Learning accurate DTs is, however, complicated, especially for oblique DTs, and does take a significant training time. Further, DTs suffer from overfitting, e.g., they proverbially "do not generalize" in regression tasks. Recently, some works proposed ways to make (oblique) DTs differentiable. This enables highly efficient gradient-descent algorithms to be used to learn DTs. It also enables generalizing capabilities by learning regressors at the leaves simultaneously with the decisions in the tree. Prior approaches to making DTs differentiable rely either on probabilistic approximations at the tree's internal nodes (soft DTs) or on approximations in gradient computation at the internal node (quantized gradient descent). In this work, we propose DTSemNet, a novel semantically equivalent and invertible encoding for (hard, oblique) DTs as Neural Networks (NNs), that uses standard vanilla gradient descent. Experiments across various classification and regression benchmarks show that oblique DTs learned using DTSemNet are more accurate than oblique DTs of similar size learned using state-of-the-art techniques. Further, DT training time is significantly reduced. We also experimentally demonstrate that DTSemNet can learn DT policies as efficiently as NN policies in the Reinforcement Learning (RL) setup with physical inputs (dimensions $\leq32$). The code is available at https://github.com/CPS-research-group/dtsemnet.
翻译:决策树(DTs)作为一类高度非线性的人工智能模型,因其在处理表格数据时的高效性而备受重视。然而,学习精确的决策树,尤其是斜决策树,过程复杂且需要大量训练时间。此外,决策树容易过拟合,例如在回归任务中普遍存在“泛化能力不足”的问题。最近,一些研究提出了使(斜)决策树可微分的方法。这使得能够使用高效的梯度下降算法来学习决策树,同时通过在树的叶子节点学习回归器,并结合树中的决策过程,提升了模型的泛化能力。先前实现决策树可微分的方法主要依赖于在树内部节点采用概率近似(软决策树)或在内部节点梯度计算中采用近似方法(量化梯度下降)。在本研究中,我们提出了DTSemNet,这是一种新颖的语义等价且可逆的编码方法,将(硬、斜)决策树编码为神经网络(NNs),并使用标准的普通梯度下降进行训练。在多种分类和回归基准测试上的实验表明,使用DTSemNet学习的斜决策树在相似规模下,比使用现有先进技术学习的斜决策树具有更高的准确性。此外,决策树的训练时间显著减少。我们还通过实验证明,在具有物理输入(维度≤32)的强化学习(RL)设置中,DTSemNet能够像学习神经网络策略一样高效地学习决策树策略。代码可在https://github.com/CPS-research-group/dtsemnet获取。