Query-key (QK) normalization stabilizes attention by controlling the scale of queries and keys before the dot product, but it is not immediately compatible with Multi-head Latent Attention (MLA). MLA achieves efficient decoding by caching low-dimensional latent states instead of full keys, whereas post-projection QK RMSNorm appears to require the fully projected key for every cached token. We show that this apparent incompatibility is an implementation artifact, not an architectural constraint. RMSNorm decomposes into a static affine weight and a dynamic scalar RMS statistic. The static key-side weight can be absorbed into the MLA query-side projection; the dynamic key statistic reduces to one inverse-RMS scalar per token and KV group. The resulting formulation is equivalent to explicit post-projection QK RMSNorm in exact arithmetic and preserves MLA's latent decode path. In our 400M-parameter runs trained for up to 100B tokens, QK-Normed MLA achieves lower training loss and better downstream accuracy than QK clipping, while H800 decode benchmarks show less than 2% latency overhead at context lengths up to 256k. These results make QK normalization a practical stabilization option for MLA models without requiring full-key caching.
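The decomposition described above can be checked numerically. The following is a minimal sketch, not the authors' implementation: it assumes a single head, omits RoPE and softmax scaling, and assumes the inverse-RMS scalar is computed once when a token is written to the cache. All names (`W_uk`, `g_q`, `g_k`, `key_scale`, `latent_score`) are hypothetical and chosen for illustration only.

```python
import math
import random


def inv_rms(x, eps=1e-6):
    """Inverse root-mean-square of a vector: the dynamic RMSNorm statistic."""
    return 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)


def rms_norm(x, weight, eps=1e-6):
    """RMSNorm = static elementwise weight times dynamic inverse-RMS scalar."""
    s = inv_rms(x, eps)
    return [w * v * s for w, v in zip(weight, x)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def matvec(W, x):
    """W @ x for a row-major list-of-lists matrix."""
    return [dot(row, x) for row in W]


def matvec_t(W, x):
    """W^T @ x without materializing the transpose."""
    return [sum(W[i][j] * x[i] for i in range(len(W))) for j in range(len(W[0]))]


def explicit_score(q, c, W_uk, g_q, g_k):
    """Reference path: up-project the latent to a full key, then normalize it."""
    k = matvec(W_uk, c)
    return dot(rms_norm(q, g_q), rms_norm(k, g_k))


def key_scale(c, W_uk):
    """Assumed to run once per token at cache-write time; yields one scalar."""
    return inv_rms(matvec(W_uk, c))


def latent_score(q, c, s_k, W_uk, g_q, g_k):
    """Latent path: the key weight g_k is folded into the query side, and the
    cached scalar s_k rescales the dot product taken against the latent c."""
    q_n = rms_norm(q, g_q)
    q_abs = matvec_t(W_uk, [g * v for g, v in zip(g_k, q_n)])
    return s_k * dot(q_abs, c)


# Toy sizes and random values, purely illustrative.
rng = random.Random(0)
d_head, d_latent = 8, 4
W_uk = [[rng.gauss(0, 1) for _ in range(d_latent)] for _ in range(d_head)]
g_q = [rng.uniform(0.5, 1.5) for _ in range(d_head)]
g_k = [rng.uniform(0.5, 1.5) for _ in range(d_head)]
q = [rng.gauss(0, 1) for _ in range(d_head)]
c = [rng.gauss(0, 1) for _ in range(d_latent)]

ref = explicit_score(q, c, W_uk, g_q, g_k)
lat = latent_score(q, c, key_scale(c, W_uk), W_uk, g_q, g_k)
print(abs(ref - lat))  # difference should be at floating-point rounding level
```

In this sketch the decode step touches only the cached latent `c` and the cached scalar `s_k`; the full key is formed only inside `key_scale`, once per token. The two scores agree up to rounding because the computation is merely reassociated, which is why the equivalence is stated for exact arithmetic.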