Distributed learning is widely used to train large models on large datasets: parts of the model or dataset are distributed across multiple devices, and the computed results are aggregated for subsequent computations or parameter updates. Existing communication algorithms for distributed learning, such as ring all-reduce, incur heavy communication overhead between servers. Since communication in large-scale systems already runs over optical fibers, we propose an Optical In-Network-Computing (OptINC) architecture that offloads computation from the servers onto the optical interconnects. To execute gradient averaging and quantization in the optical domain, we incorporate optical devices such as Mach-Zehnder interferometers (MZIs) into the interconnects. The resulting de facto optical neural network (ONN) can effectively reduce the communication overhead of existing distributed training solutions. To reduce the complexity of the dataset needed to train this neural network, we also propose a preprocessing algorithm implemented in the optical domain. Hardware cost is lowered by approximating the weight matrices of the optical neural network with unitary and diagonal matrices, while accuracy is maintained by a proposed hardware-aware training algorithm. The proposed solution was evaluated on real distributed learning tasks, including ResNet50 on CIFAR-100 and a LLaMA-based network on Wikipedia-1B. In both cases, the proposed framework achieves training accuracy comparable to the ring all-reduce baseline while eliminating communication overhead.
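As a purely digital reference for the operation the interconnect is meant to perform, gradient averaging followed by quantization can be sketched as below. This is an illustrative sketch under stated assumptions, not the optical implementation described here: the function names `average_and_quantize` and `dequantize`, the symmetric uniform quantizer, and the 8-bit default are all hypothetical choices that the abstract does not specify.

```python
import numpy as np


def average_and_quantize(grads, num_bits=8):
    """Average per-server gradients, then quantize the mean uniformly.

    Assumption: a symmetric uniform quantizer with a single scale factor.
    The abstract does not state which quantization scheme is used.
    """
    stacked = np.stack(grads)              # shape: (num_servers, num_params)
    mean = stacked.mean(axis=0)            # element-wise gradient average

    max_abs = np.max(np.abs(mean))
    if max_abs == 0.0:
        # All-zero gradient: any scale works, so return a neutral one.
        return np.zeros(mean.shape, dtype=np.int32), 1.0

    levels = 2 ** (num_bits - 1) - 1       # e.g. 127 for 8 bits
    scale = max_abs / levels               # step size of the quantizer
    codes = np.clip(np.rint(mean / scale), -levels, levels).astype(np.int32)
    return codes, scale


def dequantize(codes, scale):
    """Map integer codes back to floating-point gradient values."""
    return codes.astype(np.float64) * scale


if __name__ == "__main__":
    # Two hypothetical servers, three parameters each.
    grads = [np.array([0.5, -1.0, 0.25]), np.array([1.5, -3.0, 0.75])]
    codes, scale = average_and_quantize(grads)
    print(codes, scale, dequantize(codes, scale))
```

With ring all-reduce, the servers themselves compute this result over several communication steps; the architecture above instead aims to produce the averaged, quantized gradient while the signals traverse the interconnect.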