We introduce ProxSkip -- a surprisingly simple and provably efficient method for minimizing the sum of a smooth ($f$) and an expensive nonsmooth proximable ($\psi$) function. The canonical approach to solving such problems is via the proximal gradient descent (ProxGD) algorithm, which is based on the evaluation of the gradient of $f$ and the prox operator of $\psi$ in each iteration. In this work we are specifically interested in the regime in which the evaluation of prox is costly relative to the evaluation of the gradient, which is the case in many applications. ProxSkip allows for the expensive prox operator to be skipped in most iterations: while its iteration complexity is $\mathcal{O}\left(\kappa \log \frac{1}{\varepsilon}\right)$, where $\kappa$ is the condition number of $f$, the number of prox evaluations is $\mathcal{O}\left(\sqrt{\kappa} \log \frac{1}{\varepsilon}\right)$ only. Our main motivation comes from federated learning, where evaluation of the gradient operator corresponds to taking a local GD step independently on all devices, and evaluation of prox corresponds to (expensive) communication in the form of gradient averaging. In this context, ProxSkip offers an effective acceleration of communication complexity. Unlike other local gradient-type methods, such as FedAvg, SCAFFOLD, S-Local-GD and FedLin, whose theoretical communication complexity is worse than, or at best matching, that of vanilla GD in the heterogeneous data regime, we obtain a provable and large improvement without any heterogeneity-bounding assumptions.
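To make the skipping mechanism concrete, the following is a minimal sketch of a ProxSkip-style loop on a toy federated problem. The abstract does not spell out the update rule, so the form used here is an assumption: a gradient step shifted by a control variate `h`, a prox applied with probability `p` using the enlarged step `gamma / p`, and a correction of `h` after each step. The three scalar "clients", their quadratics, and all function and variable names are likewise hypothetical. The prox of the consensus constraint is plain averaging, which stands in for one communication round.

```python
import math
import random


def prox_consensus(z, step):
    """Prox of the consensus indicator: every entry becomes the average.

    The result does not depend on `step`; the argument is kept only so the
    signature matches a general prox operator.
    """
    avg = sum(z) / len(z)
    return [avg] * len(z)


def proxskip(grad, prox, x0, gamma, p, num_iters, seed=0):
    """Run the assumed ProxSkip loop; return (final iterate, prox calls)."""
    rng = random.Random(seed)
    x = list(x0)
    h = [0.0] * len(x)  # control variate, initialised at zero
    prox_calls = 0
    for _ in range(num_iters):
        g = grad(x)
        # Gradient step shifted by the control variate.
        x_hat = [xi - gamma * (gi - hi) for xi, gi, hi in zip(x, g, h)]
        if rng.random() < p:
            # Rare step: evaluate the prox (here, one averaging round).
            shifted = [xh - (gamma / p) * hi for xh, hi in zip(x_hat, h)]
            x = prox(shifted, gamma / p)
            prox_calls += 1
        else:
            # Common step: skip the prox entirely.
            x = x_hat
        # Control-variate update; it leaves h unchanged when the prox is skipped.
        h = [hi + (p / gamma) * (xi - xh) for hi, xi, xh in zip(h, x, x_hat)]
    return x, prox_calls


# Toy problem (hypothetical): client i holds f_i(x) = 0.5 * a_i * (x - b_i)^2.
a = [1.0, 2.0, 4.0]
b = [3.0, -1.0, 2.0]


def grad(x):
    """Local gradients, computed independently for each client."""
    return [ai * (xi - bi) for ai, xi, bi in zip(a, x, b)]


L_smooth, mu = max(a), min(a)
kappa = L_smooth / mu
gamma = 1.0 / L_smooth       # step size 1/L
p = 1.0 / math.sqrt(kappa)   # prox probability 1/sqrt(kappa)

x_final, calls = proxskip(grad, prox_consensus, [0.0] * len(a), gamma, p, 400)
x_star = sum(ai * bi for ai, bi in zip(a, b)) / sum(a)  # consensus minimiser
print("final iterate:", x_final)
print("target:", x_star)
print("prox calls:", calls, "out of 400 iterations")
```

With `p = 1 / sqrt(kappa)`, the prox is evaluated in roughly a `1 / sqrt(kappa)` fraction of the iterations, which is the source of the gap between the $\mathcal{O}(\kappa \log \frac{1}{\varepsilon})$ iteration count and the $\mathcal{O}(\sqrt{\kappa} \log \frac{1}{\varepsilon})$ prox count stated above. On this toy problem the condition number is only 4, so the saving is modest; the sketch illustrates the mechanism rather than the scale of the improvement.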