Low-precision formats such as float8 have been introduced in machine learning accelerator hardware to improve computational efficiency for large language model training and inference. Nevertheless, adoption by the ML community has been slowed by the complex, and sometimes brittle, techniques required to match higher-precision training accuracy. In this work, we present Scalify, an end-to-end scale propagation paradigm for computational graphs, generalizing and formalizing existing tensor scaling methods. Experimental results show that Scalify supports out-of-the-box float8 matrix multiplication and gradient representation, as well as float16 optimizer state storage. Our JAX implementation of Scalify is open-sourced at https://github.com/graphcore-research/jax-scalify.
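To make the tensor scaling idea concrete, below is a minimal, hypothetical sketch (not the jax-scalify API) of scale propagation: each tensor is stored as a pair `(data, scale)` representing `data * scale`, and operations like matrix multiplication propagate scales symbolically, so the stored `data` can remain in a narrow dynamic range such as float8. The `ScaledArray` name and all functions here are illustrative assumptions.

```python
import numpy as np

class ScaledArray:
    """Hypothetical scaled tensor: logically represents `data * scale`.

    In a real low-precision setting `data` would be stored in float8
    while `scale` stays in higher precision; here both are float64
    for clarity.
    """
    def __init__(self, data, scale):
        self.data = np.asarray(data, dtype=np.float64)
        self.scale = float(scale)

    def to_array(self):
        # Materialize the full-precision value.
        return self.data * self.scale

def scaled_matmul(a, b):
    # Scales propagate symbolically:
    #   (Da * sa) @ (Db * sb) == (Da @ Db) * (sa * sb)
    # so the matmul runs entirely on the narrow-range `data` tensors.
    return ScaledArray(a.data @ b.data, a.scale * b.scale)

def renormalize(x):
    # Fold the data's magnitude back into the scale so that `data`
    # stays near unit range (keeping it representable in float8).
    m = float(np.max(np.abs(x.data)))
    if m == 0.0:
        return x
    return ScaledArray(x.data / m, x.scale * m)
```

A graph-level "scalify" transform would apply such rules to every primitive in the computational graph, which is what lets float8 arithmetic work without per-layer manual loss scaling.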