In federated learning, a global model is learned by aggregating model updates computed at a set of independent client nodes. To reduce communication costs, each node performs multiple gradient steps before aggregation. A key challenge in this setting is data heterogeneity across clients, which produces differing local objectives and can lead clients to overly minimize their own local objective, diverging from the global solution. We demonstrate that individual client models experience catastrophic forgetting with respect to data from other clients, and we propose an efficient approach that modifies the cross-entropy objective on a per-client basis by re-weighting the softmax logits before computing the loss. This approach shields classes outside a client's label set from abrupt representation change, and we empirically demonstrate that it can alleviate client forgetting and provide consistent improvements to standard federated learning algorithms. Our method is particularly beneficial in the most challenging federated learning settings, where data heterogeneity is high and client participation in each round is low.
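The per-client re-weighting can be illustrated with a minimal Python sketch. It rests on assumptions not stated in the abstract: each client's class weights are taken to be its empirical label proportions, and re-weighting is taken to multiply each class's exponentiated logit by its weight, p_c = w_c exp(z_c) / sum_j w_j exp(z_j). This is equivalent to adding log w_c to each logit before a standard softmax. The function names and the `eps` floor are hypothetical. Because the gradient of the loss with respect to logit z_c is p_c minus the one-hot label, a class with near-zero weight receives a near-zero gradient. This is one way to read how classes outside a client's label set are "shielded".

```python
import math
from typing import List, Sequence


def client_class_weights(labels: Sequence[int], num_classes: int) -> List[float]:
    """Hypothetical helper: a client's empirical label proportions (assumed weighting)."""
    if not labels:
        raise ValueError("client has no labeled examples")
    counts = [0] * num_classes
    for y in labels:
        counts[y] += 1
    return [c / len(labels) for c in counts]


def _adjusted_logits(logits: Sequence[float], class_weights: Sequence[float],
                     eps: float) -> List[float]:
    # Multiplying exp(z_c) by w_c is the same as adding log(w_c) to z_c.
    # eps (an assumption) keeps log() finite for classes the client never sees.
    if len(logits) != len(class_weights):
        raise ValueError("logits and class_weights must have the same length")
    return [z + math.log(w + eps) for z, w in zip(logits, class_weights)]


def reweighted_softmax(logits: Sequence[float], class_weights: Sequence[float],
                       eps: float = 1e-12) -> List[float]:
    """p_c = w_c * exp(z_c) / sum_j w_j * exp(z_j), computed stably."""
    adjusted = _adjusted_logits(logits, class_weights, eps)
    m = max(adjusted)  # subtract the max before exponentiating to avoid overflow
    exps = [math.exp(a - m) for a in adjusted]
    total = sum(exps)
    return [e / total for e in exps]


def reweighted_softmax_ce(logits: Sequence[float], label: int,
                          class_weights: Sequence[float], eps: float = 1e-12) -> float:
    """Cross-entropy of the true label under the re-weighted softmax."""
    adjusted = _adjusted_logits(logits, class_weights, eps)
    m = max(adjusted)
    log_norm = m + math.log(sum(math.exp(a - m) for a in adjusted))  # log-sum-exp
    return log_norm - adjusted[label]


def reweighted_softmax_ce_grad(logits: Sequence[float], label: int,
                               class_weights: Sequence[float],
                               eps: float = 1e-12) -> List[float]:
    """dL/dz_c = p_c - 1[c == label], with p the re-weighted probabilities."""
    probs = reweighted_softmax(logits, class_weights, eps)
    return [p - (1.0 if c == label else 0.0) for c, p in enumerate(probs)]


if __name__ == "__main__":
    # A client that only holds classes 0 and 1 out of 3.
    w = client_class_weights([0, 0, 1, 1], num_classes=3)
    logits = [0.0, 0.0, 5.0]  # the model currently favors the unseen class 2
    print("weights:", w)
    print("loss:", reweighted_softmax_ce(logits, 0, w))
    print("grad:", reweighted_softmax_ce_grad(logits, 0, w))
```

With uniform weights, the constant log w cancels and the loss reduces to ordinary softmax cross-entropy. With the client-specific weights above, the gradient on the unseen class 2 is close to zero, whereas plain cross-entropy would push that logit down sharply on every local step.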