Decision Trees (DTs) are commonly used for many machine learning tasks due to their high degree of interpretability. However, learning a DT from data is a difficult optimization problem, as it is non-convex and non-differentiable. Therefore, common approaches learn DTs using a greedy growth algorithm that locally minimizes the impurity at each internal node. Unfortunately, this greedy procedure can lead to inaccurate trees. In this paper, we present a novel approach for learning hard, axis-aligned DTs with gradient descent. The proposed method uses backpropagation with a straight-through operator on a dense DT representation to jointly optimize all tree parameters. Our approach outperforms existing methods on binary classification benchmarks and achieves competitive results on multi-class tasks. An implementation is available at: https://github.com/s-marton/GradTree
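To illustrate the idea of training a hard split with a straight-through operator, the following is a minimal sketch for a single-split tree (a stump) on one feature. It is not the GradTree implementation: the function names (`forward`, `st_gradients`, `train`), the squared-error loss, the toy data, and all hyperparameters are assumptions made for illustration. The forward pass rounds a sigmoid to obtain a hard 0/1 routing decision; the backward pass treats the rounding as the identity and uses the sigmoid's derivative instead, so the threshold and both leaf values can be updated jointly by gradient descent.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def forward(x: float, t: float, left: float, right: float):
    """Hard stump: route x to the right leaf iff sigmoid(x - t) > 0.5."""
    soft = sigmoid(x - t)
    hard = 1.0 if soft > 0.5 else 0.0  # rounding makes the split hard
    pred = hard * right + (1.0 - hard) * left
    return pred, hard, soft


def st_gradients(x: float, y: float, t: float, left: float, right: float):
    """Gradients of the squared error (pred - y)^2 for one sample.

    The rounding step has zero derivative almost everywhere, so the
    straight-through operator replaces d(hard)/d(soft) with 1 (an
    assumption of this sketch, mirroring the usual estimator).
    """
    pred, hard, soft = forward(x, t, left, right)
    d_pred = 2.0 * (pred - y)
    d_left = d_pred * (1.0 - hard)
    d_right = d_pred * hard
    d_soft_dt = -soft * (1.0 - soft)  # derivative of sigmoid(x - t) w.r.t. t
    d_t = d_pred * (right - left) * d_soft_dt
    return d_t, d_left, d_right


def train(xs, ys, t=0.75, left=0.2, right=0.8, lr=0.1, steps=500):
    """Full-batch gradient descent on threshold and leaf values jointly."""
    n = len(xs)
    for _ in range(steps):
        g_t = g_l = g_r = 0.0
        for x, y in zip(xs, ys):
            d_t, d_l, d_r = st_gradients(x, y, t, left, right)
            g_t += d_t / n
            g_l += d_l / n
            g_r += d_r / n
        t -= lr * g_t
        left -= lr * g_l
        right -= lr * g_r
    return t, left, right


# Hypothetical toy data: label 0 for small x, label 1 for large x.
xs = [0.0, 0.5, 1.0, 1.5, 4.5, 5.0, 5.5, 6.0]
ys = [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
t_fit, left_fit, right_fit = train(xs, ys)
print(t_fit, left_fit, right_fit)
```

In this sketch the threshold starts inside the class-0 cluster, and the straight-through gradient from the misrouted samples should push it upward, even though the true derivative of the rounded split with respect to the threshold is zero almost everywhere. A full tree would additionally need a differentiable relaxation for choosing the split feature at each node, which this single-feature example leaves out.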