Adaptive optimizers such as Adam (Kingma & Ba, 2015) have been central to the success of large language models. However, they typically maintain optimizer states throughout training, which can raise memory requirements to several times the model footprint. This overhead constrains scalability and computational efficiency. Stochastic Gradient Descent (SGD), in contrast, is stateless: it tracks no state variables during training and is therefore optimally memory-efficient. However, its capability in LLM training is limited (Zhao et al., 2024b). In this work, we show that pre-processing SGD gradients in a stateless manner can match the performance of Adam for LLM training while drastically reducing memory cost. Specifically, we pre-process the instantaneous stochastic gradients with normalization and whitening. We show that normalization stabilizes gradient distributions, and whitening counteracts the local curvature of the loss landscape. The result is SWAN (SGD with Whitening And Normalization), a stochastic optimizer that eliminates the need to store any optimizer states. Empirically, SWAN has the same memory footprint as SGD, reducing total end-to-end memory by $\approx 50\%$ compared to Adam. On language modeling tasks, SWAN performs comparably to or better than Adam: when pre-training LLaMA models with 350M and 1.3B parameters, SWAN achieves a 2x speedup, reaching the same evaluation perplexity with half as many tokens.
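The two pre-processing steps can be illustrated with a minimal NumPy sketch. This is not the paper's implementation: the abstract does not specify the exact operators, so here we assume "normalization" means row-wise standardization of the gradient matrix and "whitening" means multiplying by the inverse square root of the gradient covariance $GG^\top$, computed by eigendecomposition for clarity. The function name `swan_preprocess` and the `eps` regularizer are illustrative choices.

```python
import numpy as np

def swan_preprocess(grad, eps=1e-8):
    """Stateless gradient pre-processing sketch: normalize, then whiten.

    Assumptions (not fixed by the abstract): normalization is row-wise
    standardization; whitening left-multiplies by (G G^T + eps I)^{-1/2},
    computed via a dense eigendecomposition.
    """
    G = np.asarray(grad, dtype=np.float64)
    # Normalization: give each row zero mean and unit variance,
    # stabilizing the gradient distribution across training steps.
    G = (G - G.mean(axis=1, keepdims=True)) / (G.std(axis=1, keepdims=True) + eps)
    # Whitening: (G G^T + eps I)^{-1/2} G decorrelates the rows,
    # counteracting local curvature of the loss landscape.
    C = G @ G.T + eps * np.eye(G.shape[0])
    w, V = np.linalg.eigh(C)
    inv_sqrt = V @ np.diag(1.0 / np.sqrt(np.maximum(w, eps))) @ V.T
    return inv_sqrt @ G
```

Because no running averages are kept between calls, the update is stateless: each step consumes only the instantaneous gradient, which is why the optimizer's memory footprint matches plain SGD. A production version would replace the eigendecomposition with a cheaper iterative inverse-square-root scheme.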