Adaptive optimizers such as Adam (Kingma & Ba, 2015) have been central to the success of large language models. However, they typically require maintaining optimizer states throughout training, which can result in memory requirements several times greater than the model footprint itself. This overhead constrains scalability and computational efficiency. Stochastic Gradient Descent (SGD), in contrast, is a stateless optimizer: it tracks no state variables during training and therefore achieves optimal memory efficiency. However, its effectiveness in LLM training is limited (Zhao et al., 2024b). In this work, we show that pre-processing SGD gradients in a stateless manner can match the performance of the Adam optimizer for LLM training while drastically reducing memory cost. Specifically, we propose to pre-process the instantaneous stochastic gradients using normalization and whitening. We show that normalization stabilizes gradient distributions, while whitening counteracts the local curvature of the loss landscape. The result is SWAN (SGD with Whitening And Normalization), a stochastic optimizer that eliminates the need to store any optimizer states. Empirically, SWAN has the same memory footprint as SGD, achieving an $\approx 50\%$ reduction in total end-to-end memory compared to Adam. On language modeling tasks, SWAN matches or exceeds Adam's performance: when pre-training LLaMA models with 350M and 1.3B parameters, SWAN achieves a 2x speedup, reaching the same evaluation perplexity using half as many tokens.
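The two pre-processing steps described above can be sketched as follows. This is a minimal illustration, not the authors' reference implementation: the function name, the per-row standardization, and the eigendecomposition-based inverse matrix square root are all assumed details chosen for clarity.

```python
import numpy as np

def swan_preprocess(G, eps=1e-8):
    """Illustrative sketch (assumed details) of stateless gradient
    pre-processing: per-row normalization followed by whitening."""
    # Normalization: standardize each row of the gradient matrix,
    # stabilizing the gradient distribution without any running state.
    mean = G.mean(axis=1, keepdims=True)
    std = G.std(axis=1, keepdims=True)
    G_norm = (G - mean) / (std + eps)

    # Whitening: left-multiply by (G G^T)^{-1/2}, pushing the update's
    # singular values toward 1 to counteract local curvature.
    C = G_norm @ G_norm.T
    w, V = np.linalg.eigh(C)  # C is symmetric PSD
    inv_sqrt = V @ np.diag(1.0 / np.sqrt(np.maximum(w, eps))) @ V.T
    return inv_sqrt @ G_norm

# A plain SGD step would then use the pre-processed gradient:
#   W -= lr * swan_preprocess(grad_W)
```

Because both steps depend only on the current gradient, no optimizer state is carried between iterations, which is the source of the SGD-level memory footprint.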