The canonical formulation of federated learning treats it as a distributed optimization problem in which the model parameters are optimized against a global loss function that decomposes into a sum of client loss functions. A recent alternative formulation instead treats federated learning as a distributed inference problem, where the goal is to infer a global posterior from partitioned client data (Al-Shedivat et al., 2021). This paper extends the inference view and describes a variational inference formulation of federated learning, where the goal is to find a global variational posterior that closely approximates the true posterior. This naturally motivates an expectation propagation approach to federated learning (FedEP), in which approximations to the global posterior are iteratively refined through probabilistic message passing between the central server and the clients. We conduct an extensive empirical study of the algorithmic choices involved and describe practical strategies for scaling expectation propagation to the modern federated setting. We apply FedEP to standard federated learning benchmarks and find that it outperforms strong baselines in terms of both convergence speed and accuracy.
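To make the message-passing idea concrete, the following is a minimal sketch of expectation propagation on a toy problem, not the implementation used in the paper. It assumes a scalar parameter with a Gaussian prior and Gaussian client likelihoods with known noise variance, so the projection step is exact and the global approximation matches the true posterior. All names (`Gaussian1D`, `client_update`, `fed_ep`) and default values are hypothetical and chosen only for illustration.

```python
from dataclasses import dataclass


@dataclass
class Gaussian1D:
    """Unnormalized 1-D Gaussian in natural parameters (illustrative helper)."""
    eta: float  # precision times mean
    lam: float  # precision

    def __mul__(self, other):
        # A product of Gaussians adds natural parameters.
        return Gaussian1D(self.eta + other.eta, self.lam + other.lam)

    def __truediv__(self, other):
        # A ratio of Gaussians subtracts natural parameters.
        return Gaussian1D(self.eta - other.eta, self.lam - other.lam)

    @property
    def mean(self):
        return self.eta / self.lam

    @property
    def var(self):
        return 1.0 / self.lam


def client_likelihood(data, noise_var):
    # Likelihood of the client's data as a function of the unknown mean.
    return Gaussian1D(sum(data) / noise_var, len(data) / noise_var)


def client_update(q_global, site, data, noise_var):
    """One client-side EP step: cavity, tilted distribution, new site."""
    cavity = q_global / site  # remove this client's current contribution
    tilted = cavity * client_likelihood(data, noise_var)
    # Assumption: in this conjugate toy model the tilted distribution is
    # already Gaussian, so projecting it onto the Gaussian family is the
    # identity. A non-conjugate model would need approximate inference here.
    return tilted / cavity


def fed_ep(client_data, prior_var=10.0, noise_var=1.0, rounds=3):
    prior = Gaussian1D(0.0, 1.0 / prior_var)
    # Sites start as uninformative factors (zero natural parameters).
    sites = [Gaussian1D(0.0, 0.0) for _ in client_data]
    q_global = prior
    for _ in range(rounds):
        # Clients compute new site approximations from the same global state.
        sites = [
            client_update(q_global, site, data, noise_var)
            for site, data in zip(sites, client_data)
        ]
        # Server: the global approximation is the prior times all sites.
        q_global = prior
        for site in sites:
            q_global = q_global * site
    return q_global


if __name__ == "__main__":
    clients = [[1.0, 2.0], [3.0], [0.5, 1.5, 2.5]]
    q = fed_ep(clients)
    print(q.mean, q.var)
```

In this conjugate setting the procedure reaches the exact posterior after a single round, and further rounds leave it unchanged. With neural network models the tilted distribution is intractable, which is why the choice of client-side inference and projection matters in practice.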