Sharding a large machine learning model across multiple devices to balance the costs is important in distributed training. This is challenging because partitioning is NP-hard, and estimating the costs accurately and efficiently is difficult. In this work, we explore a "pre-train, and search" paradigm for efficient sharding. The idea is to pre-train a universal and once-for-all neural network to predict the costs of all the possible shards, which serves as an efficient sharding simulator. Built upon this pre-trained cost model, we then perform an online search to identify the best sharding plans given any specific sharding task. We instantiate this idea in deep learning recommendation models (DLRMs) and propose NeuroShard for embedding table sharding. NeuroShard pre-trains neural cost models on augmented tables to cover various sharding scenarios. Then it identifies the best column-wise and table-wise sharding plans with beam search and greedy grid search, respectively. Experiments show that NeuroShard significantly and consistently outperforms the state-of-the-art on the benchmark sharding dataset, achieving up to 23.8% improvement. When deployed in an ultra-large production DLRM with multi-terabyte embedding tables, NeuroShard achieves 11.6% improvement in embedding costs over the state-of-the-art, which translates to 6.6% end-to-end training throughput improvement. To facilitate future research of the "pre-train, and search" paradigm in ML for Systems, we open-source our code at https://github.com/daochenzha/neuroshard
翻译:将大型机器学习模型在多设备间进行分片以平衡训练代价,是分布式训练中的重要问题。该问题具有挑战性,因为划分是NP难的,且准确高效地估计代价十分困难。本文探索了一种"预训练-搜索"范式以实现高效分片。其核心思想是预训练一个通用且一次性的神经网络,用于预测所有可能分片的代价,从而作为高效的分片模拟器。基于此预训练代价模型,我们针对任意特定分片任务执行在线搜索,以确定最优分片方案。我们在深度学习推荐模型(DLRM)中实例化该思想,并提出NeuroShard用于嵌入表分片。NeuroShard在增强表上预训练神经代价模型以覆盖多种分片场景,随后分别通过束搜索和贪心网格搜索确定最优的列级和表级分片方案。实验表明,NeuroShard在基准分片数据集上显著且一致地优于现有最优方法,性能提升最高达23.8%。当部署于包含数太字节嵌入表的超大规模生产级DLRM时,NeuroShard在嵌入代价上较现有最优方法实现11.6%的改进,转化为6.6%的端到端训练吞吐量提升。为促进机器学习系统领域"预训练-搜索"范式的未来研究,我们已在https://github.com/daochenzha/neuroshard 开源代码。