Training large models is constrained by high compute cost and limited hardware memory. Low-precision representation is a practical remedy, but it suffers from loss of numerical accuracy and unstable training, which make the resulting model less useful. We argue that low-precision floating-point arithmetic can perform well provided the error is properly compensated at the critical locations in the training process. We propose Collage, which uses a multi-component float representation in low precision to perform operations accurately, with numerical errors accounted for. To understand the impact of imprecision on training, we propose a simple and novel metric that tracks the information lost during training and differentiates between precision strategies. Our method works with commonly used low-precision formats such as half precision ($16$-bit floating point) and extends naturally to even lower precision such as $8$-bit. Experimental results show that pre-training with Collage removes the need for $32$-bit floating-point copies of the model and attains similar or better training performance compared to the $(16, 32)$-bit mixed-precision strategy, with up to $3.7\times$ speedup and $\sim 15\%$ to $23\%$ less memory usage in practice.
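To make the idea of error compensation with a multi-component float concrete, the following is a minimal sketch and not the Collage implementation. It assumes a weight stored as an unevaluated sum of two float16 values `(hi, lo)` and uses Knuth's classical TwoSum to capture the rounding error of each update. The names `two_sum`, `grow`, and `simulate`, as well as the starting weight `1.0` and step size `1e-4`, are hypothetical choices made for illustration.

```python
import numpy as np

F16 = np.float16  # IEEE half precision


def two_sum(a, b):
    """Knuth's TwoSum in float16.

    Returns (s, e) where s is the rounded sum fl(a + b) and e is the
    rounding error, so that s + e equals a + b exactly (barring overflow).
    """
    s = F16(a + b)
    b_virtual = F16(s - a)
    a_virtual = F16(s - b_virtual)
    e = F16(F16(a - a_virtual) + F16(b - b_virtual))
    return s, e


def grow(hi, lo, delta):
    """Add delta to the two-component value (hi, lo).

    The error of hi + delta is folded into lo, then the pair is
    renormalized so that hi again carries the leading bits.
    """
    s, e = two_sum(hi, delta)
    lo = F16(lo + e)
    return two_sum(s, lo)


def simulate(steps=1000, step_size=1e-4):
    """Apply the same small update repeatedly, with and without compensation."""
    delta = F16(step_size)
    naive = F16(1.0)               # plain float16 weight
    hi, lo = F16(1.0), F16(0.0)    # two-component float16 weight
    for _ in range(steps):
        naive = F16(naive + delta)
        hi, lo = grow(hi, lo, delta)
    return naive, hi, lo


if __name__ == "__main__":
    naive, hi, lo = simulate()
    print("naive float16:", float(naive))
    print("two-component:", float(hi) + float(lo))
```

In this setting the plain float16 weight never moves from `1.0`, because each step is smaller than half the float16 spacing near `1.0` (about `4.9e-4`) and is rounded away. The two-component weight ends close to the intended value of about `1.1` while still using only 16-bit storage, which is the kind of lost update that compensation at critical locations is meant to recover.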