Federated learning-assisted edge intelligence enables privacy protection in modern intelligent services. However, non-independent and identically distributed (non-IID) data across edge clients can impair local model performance. The existing single-prototype strategy represents a class by the mean of its feature embeddings. However, the embeddings of a class usually do not form a single compact cluster, so a single prototype may not represent the class well. Motivated by this, this paper proposes a multi-prototype federated contrastive learning approach (MP-FedCL) and demonstrates the effectiveness of a multi-prototype strategy over a single-prototype strategy under non-IID settings, including both label and feature skewness. Specifically, a multi-prototype computation strategy based on \textit{k-means} is first proposed to capture different embedding representations within each class, using multiple prototypes ($k$ centroids) to represent a class in the embedding space. In each global round, the computed prototypes and their respective model parameters are sent to the edge server for aggregation into a global prototype pool, which is then sent back to all clients to guide their local training. Finally, each client's local training minimizes its own supervised learning loss and learns from the shared prototypes in the global prototype pool through supervised contrastive learning. This encourages each client to learn knowledge related to its own classes from other clients and reduces the absorption of unrelated knowledge in each global iteration. Experimental results on MNIST, Digit-5, Office-10, and DomainNet show that our method outperforms multiple baselines, with average test accuracy improvements of about 4.6\% and 10.4\% under feature and label non-IID distributions, respectively.
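The three steps described above (per-class k-means prototypes, a server-side prototype pool, and a contrastive loss against that pool) can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the MP-FedCL implementation: the function names (`kmeans`, `local_prototypes`, `build_global_pool`, `proto_contrastive_loss`) and the temperature `tau` are hypothetical; the global pool is assumed to be the per-class union of all clients' prototypes; the loss is assumed to be a SupCon-style log-softmax over cosine similarities to every prototype, averaged over the same-class prototypes. Aggregation of the model parameters is omitted.

```python
import numpy as np

def kmeans(x, k, iters=20, seed=0):
    """Plain Lloyd's k-means; returns a (k, d) array of centroids."""
    rng = np.random.default_rng(seed)
    k = min(k, len(x))  # cannot have more centroids than points
    centroids = x[rng.choice(len(x), size=k, replace=False)].copy()
    for _ in range(iters):
        # squared Euclidean distance from every point to every centroid
        d = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
        assign = d.argmin(axis=1)
        for j in range(k):
            members = x[assign == j]
            if len(members):  # keep the old centroid if a cluster is empty
                centroids[j] = members.mean(axis=0)
    return centroids

def local_prototypes(emb, labels, k):
    """Hypothetical helper: run k-means inside each class, class c -> (k_c, d) prototypes."""
    return {int(c): kmeans(emb[labels == c], k) for c in np.unique(labels)}

def build_global_pool(client_protos):
    """ASSUMED aggregation: the pool is the per-class union of all clients' prototypes."""
    pool = {}
    for protos in client_protos:
        for c, p in protos.items():
            pool.setdefault(c, []).append(p)
    return {c: np.concatenate(ps, axis=0) for c, ps in pool.items()}

def proto_contrastive_loss(z, y, pool, tau=0.5):
    """ASSUMED SupCon-style loss of one embedding z (label y) against the pool.
    Class y must be present in the pool, otherwise there are no positives."""
    classes = sorted(pool)
    protos = np.concatenate([pool[c] for c in classes], axis=0)
    owners = np.concatenate([np.full(len(pool[c]), c) for c in classes])
    # cosine similarity between z and every prototype, scaled by the temperature
    zn = z / np.linalg.norm(z)
    pn = protos / np.linalg.norm(protos, axis=1, keepdims=True)
    logits = pn @ zn / tau
    # numerically stable log-softmax over all prototypes
    shifted = logits - logits.max()
    log_prob = shifted - np.log(np.exp(shifted).sum())
    # average negative log-likelihood over the same-class (positive) prototypes
    return -log_prob[owners == y].mean()

# Toy usage: two clients with 2-D embeddings, class 0 near (5, 0), class 1 near (-5, 0)
def toy_client(seed):
    r = np.random.default_rng(seed)
    emb = np.vstack([r.normal([5.0, 0.0], 0.3, (20, 2)),
                     r.normal([-5.0, 0.0], 0.3, (20, 2))])
    return emb, np.array([0] * 20 + [1] * 20)

clients = [toy_client(s) for s in (1, 2)]
pool = build_global_pool([local_prototypes(e, l, k=2) for e, l in clients])
z = np.array([5.0, 0.0])
loss_pos = proto_contrastive_loss(z, 0, pool)  # z labelled with its true class
loss_neg = proto_contrastive_loss(z, 1, pool)  # same z, wrong label
```

With $k=2$ prototypes per class on each of two clients, every class ends up with four prototypes in the pool, and an embedding that lies near its own class's prototypes receives a lower loss than the same embedding paired with a wrong label. Using several centroids per class rather than one mean lets a class whose embeddings split into sub-clusters (for example, the same digit rendered in different domains) keep a positive target near each sub-cluster.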