Despite recent advances in federated learning (FL), the integration of generative models into FL has been limited by challenges such as high communication costs and unstable training in heterogeneous data environments. To address these issues, we propose PRISM, an FL framework tailored for generative models that ensures (i) stable performance under heterogeneous data distributions and (ii) resource efficiency in terms of communication cost and final model size. The key idea of our method is to search for an optimal stochastic binary mask for a random network rather than updating the model weights, thereby identifying a sparse subnetwork with high generative performance, i.e., a ``strong lottery ticket''. By communicating binary masks in a stochastic manner, PRISM minimizes communication overhead. This approach, combined with a maximum mean discrepancy (MMD) loss and a mask-aware dynamic moving average aggregation method (MADA) on the server side, yields stable and strong generative capabilities by mitigating local divergence in FL scenarios. Moreover, because the learned mask is sparse, PRISM yields a lightweight model without extra pruning or quantization, making it well suited to resource-constrained environments such as edge devices. Experiments on MNIST, FMNIST, CelebA, and CIFAR10 demonstrate that PRISM outperforms existing methods while maintaining privacy with minimal communication costs. PRISM is the first to successfully generate images under challenging non-IID and privacy-preserving FL environments on complex datasets, where previous methods have struggled.
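The ingredients named above can be illustrated with a minimal NumPy sketch. It is not the PRISM implementation: the sigmoid parameterization of mask probabilities, the plain averaging on the server (a stand-in for MADA, whose update rule is not given here), the RBF kernel, and all function names are assumptions made for illustration only.

```python
import numpy as np


def sample_mask(scores, rng):
    """Sample a binary mask; each entry is 1 with probability sigmoid(score).

    Assumption: per-weight scores are mapped to probabilities with a sigmoid.
    """
    probs = 1.0 / (1.0 + np.exp(-scores))
    return (rng.random(scores.shape) < probs).astype(np.uint8)


def aggregate_masks(masks):
    """Server side: the mean of the clients' binary masks estimates a
    per-weight keep-probability. Plain averaging, NOT the MADA rule."""
    return np.mean(np.stack(masks), axis=0)


def masked_forward(x, frozen_weights, mask):
    """Apply a layer whose random weights stay fixed; only the mask is learned."""
    return x @ (frozen_weights * mask)


def rbf_mmd2(x, y, bandwidth=1.0):
    """Biased estimate of squared MMD between samples x and y (RBF kernel).

    Assumption: the kernel and bandwidth are illustrative choices.
    """
    def kernel(a, b):
        sq_dist = (a ** 2).sum(1)[:, None] + (b ** 2).sum(1)[None, :] - 2 * a @ b.T
        return np.exp(-sq_dist / (2 * bandwidth ** 2))

    return kernel(x, x).mean() + kernel(y, y).mean() - 2 * kernel(x, y).mean()


rng = np.random.default_rng(0)
frozen_weights = rng.standard_normal((4, 3))  # random initialization, never updated

# Each of 5 hypothetical clients holds its own scores and uploads one sampled mask.
client_scores = [rng.standard_normal((4, 3)) for _ in range(5)]
client_masks = [sample_mask(s, rng) for s in client_scores]
global_probs = aggregate_masks(client_masks)

# Upload cost: one bit per weight, versus 32 bits for a float32 weight update.
payload_bits = client_masks[0].size

# MMD between "real" and "generated" feature batches (random stand-ins here).
real = rng.standard_normal((8, 4))
fake = masked_forward(rng.standard_normal((8, 4)), frozen_weights, client_masks[0])
fake = np.pad(fake, ((0, 0), (0, 1)))  # pad to the same feature width as `real`
mmd_value = rbf_mmd2(real, fake)
```

In this sketch the only quantities that travel between client and server are the binary masks and the averaged probabilities, which is where the communication savings described above would come from.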