Although Shapley additive explanations (SHAP) can be computed in polynomial time for simple models such as decision trees, computing them becomes NP-hard for more expressive black-box models such as neural networks, which is often precisely where explanations are most needed. In this work, we analyze the problem of computing SHAP explanations for *Tensor Networks (TNs)*, a class of models that is broader and more expressive than those for which exact SHAP algorithms are currently known, and which is widely used for neural network abstraction and compression. First, we introduce a general framework for computing provably exact SHAP explanations for TNs with arbitrary structure. Interestingly, we show that when TNs are restricted to a *Tensor Train (TT)* structure, SHAP computation can be performed in *poly-logarithmic* time using *parallel* computation. Thanks to the expressive power of TTs, this complexity result extends to many other popular ML models, such as decision trees, tree ensembles, linear models, and linear RNNs, thereby tightening previously reported complexity results for these model families. Finally, by leveraging reductions of binarized neural networks to Tensor Network representations, we show that SHAP computation becomes *efficiently tractable* when the network's *width* is fixed, whereas it remains computationally hard even when the *depth* is constant. This highlights an important insight: for this class of models, width, rather than depth, is the primary computational bottleneck in SHAP computation.
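To make the objects in this abstract concrete, here is a minimal Python sketch (NumPy assumed) of a binary-input Tensor Train model and its exact SHAP values. It is *not* the framework or the parallel poly-logarithmic algorithm described above: it enumerates every coalition and is therefore exponential in the number of features, serving only as a reference baseline. It assumes binary features, an interventional value function, and an independent (product) background distribution in which feature `i` equals 1 with probability `p[i]`; the function names `tt_eval`, `tt_value`, and `shap_values` are hypothetical. What it does illustrate is why TTs are convenient here: because the model is multilinear in its per-feature cores, the expectation over a missing feature can be taken by averaging that feature's core, so each coalition value v(S) costs a single chain of matrix products.

```python
import itertools
import math

import numpy as np

# Assumption (illustrative only): each core has shape (2, r_prev, r_next),
# with boundary ranks r_0 = r_n = 1, so f(x) = G_1[x_1] @ ... @ G_n[x_n].


def tt_eval(cores, x):
    """Evaluate the tensor-train model at a binary input x."""
    out = np.ones((1, 1))
    for core, xi in zip(cores, x):
        out = out @ core[xi]
    return float(out[0, 0])


def tt_value(cores, x, coalition, p):
    """Interventional value v(S): features in S are fixed to x, the others are
    drawn independently with P(x_i = 1) = p[i] (assumed product distribution).
    By multilinearity, a missing feature's core is replaced by its expectation."""
    out = np.ones((1, 1))
    for i, core in enumerate(cores):
        if i in coalition:
            m = core[x[i]]
        else:
            m = (1 - p[i]) * core[0] + p[i] * core[1]
        out = out @ m
    return float(out[0, 0])


def shap_values(cores, x, p):
    """Exact SHAP by brute-force enumeration of coalitions (exponential in n)."""
    n = len(cores)
    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for k in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions of size k
            w = math.factorial(k) * math.factorial(n - k - 1) / math.factorial(n)
            for subset in itertools.combinations(others, k):
                s = set(subset)
                phi[i] += w * (tt_value(cores, x, s | {i}, p) - tt_value(cores, x, s, p))
    return phi


# Usage on a small random TT with ranks (1, 3, 3, 2, 1) over 4 binary features.
rng = np.random.default_rng(0)
ranks = [1, 3, 3, 2, 1]
cores = [rng.normal(size=(2, ranks[i], ranks[i + 1])) for i in range(4)]
x = [1, 0, 1, 1]
p = [0.5, 0.3, 0.7, 0.5]
phi = shap_values(cores, x, p)

# Efficiency property: attributions sum to f(x) minus the expected output v(empty set).
assert np.isclose(phi.sum(), tt_eval(cores, x) - tt_value(cores, x, set(), p))
```

The closing assertion checks the SHAP efficiency property. Replacing the coalition enumeration with something polynomial, or poly-logarithmic in parallel, is exactly what the exact algorithms discussed in the abstract are about; this sketch deliberately does not attempt that.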