This work proposes a Momentum-Enabled Kronecker-Factor-Based Optimizer Using Rank-1 updates, called MKOR, that improves the training time and convergence properties of deep neural networks (DNNs). Second-order techniques enjoy higher convergence rates than their first-order counterparts, but their complexity is cubic in the model size and/or the training batch size. As a result, they scale and perform poorly on transformer models, e.g. large language models (LLMs). In these models the batch size grows with the sequence length of the attention mechanism, so both the model size and the batch size are large. MKOR's complexity is quadratic with respect to the model size, which alleviates the computation bottlenecks of second-order methods. Because of their high computation complexity, state-of-the-art implementations of second-order methods can only afford to update the second-order information infrequently. They therefore do not fully exploit the promise of better convergence from these updates. MKOR reduces the computation complexity of the second-order updates and achieves linear communication complexity, which lets it increase the frequency of second-order updates. We also propose a hybrid version of MKOR (called MKOR-H) that falls back to a first-order optimizer mid-training if the second-order updates no longer accelerate convergence. Our experiments on BERT-Large-Uncased on 64 GPUs show that MKOR outperforms state-of-the-art first-order methods, e.g. the LAMB optimizer, by up to 2.57x. It also outperforms the best implementations of second-order methods, i.e. KAISA/KFAC, by up to 1.85x.
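The abstract does not spell out MKOR's update rule. The sketch below shows how a rank-1 update can keep the cost of refreshing an inverted Kronecker factor quadratic rather than cubic in the layer dimension d. It assumes that the factor is maintained as an exponential moving average with one rank-1 term per step, `L_new = beta * L + (1 - beta) * a a^T`. Under that assumption, the Sherman-Morrison formula gives the new inverse directly from the old one. The function name `ema_rank1_inverse_update` and the EMA form are hypothetical illustrations, not MKOR's published algorithm.

```python
import numpy as np


def ema_rank1_inverse_update(L_inv, a, beta):
    """Return inv(beta * L + (1 - beta) * outer(a, a)) given L_inv = inv(L).

    Hypothetical helper (not MKOR's API). It applies the Sherman-Morrison
    formula, so it needs only O(d^2) work instead of re-inverting the d x d
    factor from scratch in O(d^3). It assumes L is symmetric positive definite
    and 0 < beta < 1.
    """
    gamma = (1.0 - beta) / beta
    w = L_inv @ a                   # O(d^2) matrix-vector product
    denom = 1.0 + gamma * (a @ w)   # scalar, positive when L is SPD
    # inv(L + gamma*a a^T) = L_inv - gamma * w w^T / denom, then rescale by 1/beta
    return (L_inv - (gamma / denom) * np.outer(w, w)) / beta


# Check the rank-1 update against a direct O(d^3) inversion.
rng = np.random.default_rng(0)
d = 6
X = rng.standard_normal((d, d))
L = X @ X.T + d * np.eye(d)         # a well-conditioned SPD factor
a = rng.standard_normal(d)
beta = 0.9

fast = ema_rank1_inverse_update(np.linalg.inv(L), a, beta)
exact = np.linalg.inv(beta * L + (1 - beta) * np.outer(a, a))
print(np.allclose(fast, exact))     # True
```

The update touches the old inverse only through one matrix-vector product and one outer product. This is the kind of saving that lets a rank-1 scheme refresh second-order information more often than a full re-inversion would allow.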