Despite significant work on low-bit quantization-aware training (QAT), there is still a large accuracy gap between such techniques and native training. To address this, we introduce CAGE (Curvature-Aware Gradient Estimation), a new QAT method that augments the straight-through estimator (STE) gradient with a curvature-aware correction designed to counteract the loss increase induced by quantization. CAGE is derived from a multi-objective view of QAT that balances loss minimization with adherence to quantization constraints, yielding a principled correction term that depends on local curvature information. On the theoretical side, we introduce the notion of Pareto-optimal solutions for quantized optimization, and establish that CAGE yields strong convergence guarantees in the smooth non-convex setting. In terms of implementation, our approach is optimizer-agnostic, but we provide a highly efficient implementation that leverages Adam statistics. When pre-training Llama-style models of up to 800M parameters, CAGE recovers over 10% of the quantization-induced loss increase in the W4A4 regime relative to outlier-mitigation methods. These results indicate that curvature-aware gradient corrections can bridge the remaining performance gap beyond current outlier-handling methods.
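To make the idea concrete, the sketch below illustrates the general shape of such an update: an STE gradient step augmented with a correction that pulls the full-precision weights toward their quantized values, weighted by a diagonal curvature proxy built from Adam's second-moment statistics. This is a minimal illustration under our own assumptions, not the paper's implementation; the quantizer, the coefficient `lambda_`, and the use of `v_hat` as a curvature proxy are all illustrative choices.

```python
# Minimal sketch (illustrative, not the authors' implementation):
# an STE gradient step plus a curvature-weighted correction toward the
# quantization constraint w ≈ quantize(w). The Adam second-moment
# estimate v_hat is used here as a cheap diagonal-curvature proxy.

import torch

def quantize(w: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
    """Symmetric uniform quantization (illustrative stand-in)."""
    qmax = 2 ** (num_bits - 1) - 1
    scale = w.abs().max() / qmax
    return torch.round(w / scale).clamp(-qmax, qmax) * scale

def curvature_corrected_step(w, ste_grad, v_hat, lr=1e-3, lambda_=0.1, eps=1e-8):
    """One illustrative update step.

    w        : full-precision master weights
    ste_grad : gradient of the loss w.r.t. the quantized weights,
               passed straight through to w (STE)
    v_hat    : bias-corrected Adam second-moment estimate,
               used as a diagonal proxy for local curvature
    """
    curvature = v_hat.sqrt() + eps                         # diagonal curvature proxy
    correction = lambda_ * curvature * (w - quantize(w))   # pull residual toward the quantized grid
    return w - lr * (ste_grad + correction)
```

The design intent this mirrors is that the correction is scaled by local curvature, so directions where the loss is more sensitive to the quantization residual are penalized more strongly, while the plain STE gradient is recovered when the correction coefficient is zero.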