Federated Learning (FL) has emerged as a new paradigm for training machine learning models distributively without sacrificing data security and privacy. Learning models on edge devices such as mobile phones is one of the most common use cases for FL. However, Non-identical independent distributed~(non-IID) data in edge devices easily leads to training failures. Especially, over-parameterized machine learning models can easily be over-fitted on such data, hence, resulting in inefficient federated learning and poor model performance. To overcome the over-fitting issue, we proposed an adaptive dynamic pruning approach for FL, which can dynamically slim the model by dropping out unimportant parameters, hence, preventing over-fittings. Since the machine learning model's parameters react differently for different training samples, adaptive dynamic pruning will evaluate the salience of the model's parameter according to the input training sample, and only retain the salient parameter's gradients when doing back-propagation. We performed comprehensive experiments to evaluate our approach. The results show that our approach by removing the redundant parameters in neural networks can significantly reduce the over-fitting issue and greatly improves the training efficiency. In particular, when training the ResNet-32 on CIFAR-10, our approach reduces the communication cost by 57\%. We further demonstrate the inference acceleration capability of the proposed algorithm. Our approach reduces up to 50\% FLOPs inference of DNNs on edge devices while maintaining the model's quality.
翻译:联邦学习(FL)已成为一种分布式训练机器学习模型的新范式,能够在保护数据安全与隐私的前提下进行模型训练。在移动电话等边缘设备上学习模型是FL最常见的应用场景之一。然而,边缘设备上非独立同分布(non-IID)的数据极易导致训练失败。尤其当使用过参数化的机器学习模型时,此类数据容易引发过拟合现象,从而导致联邦学习效率低下及模型性能不佳。为克服过拟合问题,我们提出了一种面向FL的自适应动态剪枝方法,该方法通过剔除不重要的参数动态精简模型,从而有效防止过拟合。由于机器学习模型参数对不同训练样本的响应存在差异,自适应动态剪枝会根据输入训练样本评估模型参数的重要性,并在反向传播时仅保留重要参数的梯度。我们开展了全面的实验以评估所提方法。结果表明,通过移除神经网络中的冗余参数,该方法能够显著缓解过拟合问题,并大幅提升训练效率。特别地,在CIFAR-10数据集上训练ResNet-32时,本方法将通信成本降低了57%。我们进一步展示了该算法的推理加速能力。所提方法在保持模型质量的同时,可将边缘设备上深度神经网络的推理计算量(FLOPs)减少高达50%。