This paper provides a fresh view of the neural network (NN) data flow problem, i.e., identifying the NN connections that are most important for the performance of the full model, through the lens of graph theory. Understanding the NN data flow provides a tool for symbolic NN analysis, e.g., robustness analysis or model repair. Unlike the standard approach to NN data flow analysis, which is based on information theory, we employ the notion of graph curvature, specifically Ollivier-Ricci curvature (ORC). The ORC has been successfully used to identify important graph edges in various domains, such as road traffic, biological, and social networks. In particular, edges with negative ORC are considered bottlenecks and as such are critical to the graph's overall connectivity, whereas positive-ORC edges are not essential. We apply this intuition to NNs as well: we 1) construct a graph induced by the NN structure and introduce the notion of neural curvature (NC) based on the ORC; 2) calculate curvatures based on activation patterns for a set of input examples; 3) aim to demonstrate that NC can indeed be used to rank edges according to their importance for the overall NN functionality. We evaluate our method through pruning experiments and show that removing negative-ORC edges quickly degrades the overall NN performance, whereas removing positive-ORC edges has little impact. The proposed method is evaluated on a variety of models trained on three image datasets, namely MNIST, CIFAR-10 and CIFAR-100. The results indicate that our method can identify a larger number of unimportant edges than state-of-the-art pruning methods.
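To make the bottleneck intuition concrete, the following is a minimal sketch of plain Ollivier-Ricci curvature on a small unweighted graph. It is not the neural curvature (NC) proposed in the paper, whose construction from activation patterns is not specified here. The toy graph (two triangles joined by a single bridge edge), the uniform measure over neighbours, and the names `bfs_distances`, `ollivier_ricci` and `ADJ` are all assumptions made for illustration. The transport cost is found by brute force over permutations, which is only feasible for tiny graphs; a practical implementation would use a linear-programming or optimal-transport solver.

```python
from collections import deque
from fractions import Fraction
from itertools import permutations
from math import gcd


def bfs_distances(adj, source):
    """Hop distances from `source` to every reachable node."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist


def ollivier_ricci(adj):
    """ORC of every edge, with uniform measures on each endpoint's neighbours.

    Assumption: unweighted, undirected, connected graph with small degrees.
    """
    dist = {u: bfs_distances(adj, u) for u in adj}
    curvature = {}
    for x in adj:
        for y in adj[x]:
            if x >= y:  # visit each undirected edge once
                continue
            dx, dy = len(adj[x]), len(adj[y])
            units = dx * dy // gcd(dx, dy)  # common number of equal mass units
            # Split each neighbour's mass into equal units so that optimal
            # transport reduces to an assignment between two equal-size lists.
            src = [n for n in adj[x] for _ in range(units // dx)]
            dst = [n for n in adj[y] for _ in range(units // dy)]
            best = min(
                sum(dist[s][t] for s, t in zip(src, perm))
                for perm in permutations(dst)
            )
            w1 = Fraction(best, units)  # Wasserstein-1 distance
            curvature[(x, y)] = 1 - w1 / dist[x][y]
    return curvature


# Hypothetical toy graph: triangles {0, 1, 2} and {3, 4, 5} joined by edge (2, 3).
ADJ = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2, 4, 5], 4: [3, 5], 5: [3, 4]}

curv = ollivier_ricci(ADJ)
# Rank edges from most negative (bottleneck-like) to most positive curvature.
for edge, value in sorted(curv.items(), key=lambda item: item[1]):
    print(edge, value)
```

In this toy graph the bridge (2, 3) is the only edge with negative curvature (-2/3), while every edge inside a triangle is positive (1/3 or 1/2), which matches the reading of negative-ORC edges as bottlenecks.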