As language models grow ever larger, so do their vocabularies. This has shifted the memory footprint of LLMs during training disproportionately onto a single layer: the cross-entropy in the loss computation. Cross-entropy builds up a logit matrix with an entry for each pair of input token and vocabulary item and, for small models, consumes an order of magnitude more memory than the rest of the LLM combined. We propose Cut Cross-Entropy (CCE), a method that computes the cross-entropy loss without materializing the logits for all tokens in global memory. Instead, CCE computes only the logit for the correct token and evaluates the log-sum-exp over all logits on the fly. We implement a custom kernel that performs the matrix multiplications and the log-sum-exp reduction over the vocabulary in flash memory, making the global memory consumption of the cross-entropy computation negligible. The effect is dramatic: taking the Gemma 2 (2B) model as an example, CCE reduces the memory footprint of the loss computation from 24 GB to 1 MB, and the total training-time memory consumption of the classifier head from 28 GB to 1 GB. To improve the throughput of CCE, we leverage the inherent sparsity of softmax and propose to skip elements of the gradient computation that have a negligible (i.e., below numerical precision) contribution to the gradient. Experiments demonstrate that this dramatic reduction in memory consumption is achieved without sacrificing training speed or convergence.
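The core mechanism described above (compute only the correct-token logit, and accumulate the log-sum-exp over the vocabulary in blocks rather than materializing the full logit matrix) can be sketched in plain NumPy. This is an illustrative reference implementation under our own naming and blocking choices, not the fused GPU kernel the work describes; a real kernel would keep each logit block in on-chip memory instead of a temporary array, and would never write an (N, V) matrix to global memory.

```python
import numpy as np

def chunked_cross_entropy(E, C, targets, chunk=1024):
    """Mean cross-entropy over a large vocabulary without ever
    materializing the full (N, V) logit matrix.

    E: (N, D) token embeddings; C: (V, D) classifier weights;
    targets: (N,) indices of the correct tokens. Logits are computed
    one vocabulary chunk at a time, and the log-sum-exp is accumulated
    online via a running maximum m and running sum s (online softmax).
    Peak extra memory is O(N * chunk) instead of O(N * V).
    """
    N = E.shape[0]
    V = C.shape[0]
    m = np.full(N, -np.inf)   # running max of logits per token
    s = np.zeros(N)           # running sum of exp(logit - m)
    correct = np.empty(N)     # logit of each token's correct class
    for v0 in range(0, V, chunk):
        logits = E @ C[v0:v0 + chunk].T              # one (N, chunk) block
        m_new = np.maximum(m, logits.max(axis=1))
        # rescale the old sum to the new max, then add this block
        s = s * np.exp(m - m_new) + np.exp(logits - m_new[:, None]).sum(axis=1)
        m = m_new
        # gather correct-token logits that fall inside this chunk
        hit = (targets >= v0) & (targets < v0 + chunk)
        correct[hit] = logits[hit, targets[hit] - v0]
    lse = m + np.log(s)                  # log-sum-exp over the full vocabulary
    return float((lse - correct).mean())  # mean cross-entropy loss
```

The same running-max/running-sum recurrence used here is what makes the on-the-fly log-sum-exp numerically stable; the backward pass can reuse the saved `lse` to recompute softmax probabilities block by block, which is where the sparsity-based gradient filtering applies.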