Matrix multiplication performance has long been the major bottleneck to scaling deep learning workloads, which has stimulated the design of new accelerators that use increasingly low-precision number formats. However, improvements in matrix multiplication performance have far outstripped improvements in the performance of reductions and elementwise computations, which are still performed in higher precision. In this work, we propose MXNorm, a drop-in replacement for RMSNorm that estimates the RMS using only the block scales already computed as part of the MXFP8 cast, enabling a 32x decrease in the size of the reduction needed for normalization. We validate our approximation method on pre-training of Llama 3 models at 125M, 1B, and 8B parameters, finding minimal loss of training accuracy compared to a baseline using RMSNorm with MXFP8 matmuls. We also demonstrate practical kernel speedups of up to 2.4x for MXNorm over RMSNorm using only torch.compile, corresponding to a 1.3% speedup in Llama 3 8B transformer layers in MXFP8 and a 2.6% speedup in NVFP4.
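To make the idea concrete, the sketch below illustrates how an RMS estimate can be formed from MXFP8-style per-block scales alone, so the normalization reduction runs over a vector 32x smaller than the input. The power-of-2 scale rule, the FP8 E4M3 maximum of 448, and the `calib` constant are assumptions for illustration, not the paper's exact estimator.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest representable magnitude in FP8 E4M3

def mxfp8_block_scales(x, block=32):
    """Per-block power-of-2 scales, in the style of an MXFP8 cast.

    Each contiguous block of 32 elements gets one scale
    2^ceil(log2(amax / FP8_MAX)), so the scaled block fits in FP8.
    Assumes len(x) is a multiple of `block` and no block is all-zero.
    """
    amax = np.abs(x.reshape(-1, block)).max(axis=1)
    exp = np.ceil(np.log2(amax / FP8_E4M3_MAX))
    return 2.0 ** exp

def mxnorm_rms_estimate(scales, calib=1.0):
    """Hypothetical RMS estimate from block scales only.

    The reduction is over the scales (N/32 values), not the N inputs.
    `calib` is a placeholder constant absorbing the expected
    block-amax-to-RMS ratio; the paper's calibration may differ.
    """
    return calib * np.sqrt(np.mean(scales ** 2))
```

Because the scales are exact powers of two, the estimate scales linearly with the input magnitude, which is the property a normalization layer relies on; any fixed amax-to-RMS ratio folds into the calibration constant.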