State-space model releases are typically coupled to fused CUDA and Triton kernels, inheriting a hard dependency on NVIDIA hardware. We show that Mamba-2's state space duality algorithm -- diagonal state structure, chunkable recurrence, and einsum-dominated compute with static control flow -- maps cleanly onto what XLA's fusion and tiling passes actually optimise, making custom kernels optional rather than required. We implement the full inference path (prefill, cached autoregressive decoding) as shaped standard primitives under XLA, without hand-written kernels, and realise the architecture's theoretical $O(1)$ state management as a compiled on-device cache requiring no host synchronisation during generation. The implementation runs unmodified on CPU, NVIDIA GPU, and Google Cloud TPU from a single JAX source. On TPU v6e across five model scales (130M--2.7B parameters), XLA-generated code reaches approximately 140 TFLOPS on single-stream prefill ($15\%$ MFU) and up to $64\%$ bandwidth utilisation on decode. Greedy decoding matches the PyTorch/CUDA reference token-for-token across 64 steps, with hidden-state agreement within float32 rounding tolerance. The pattern transfers to any SSM recurrence satisfying the same structural conditions, on any platform with a mature XLA backend. The implementation is publicly available at https://github.com/CosmoNaught/mamba2-jax and merged into the Bonsai JAX model library.
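The structural conditions the abstract names -- a diagonal state transition, a recurrence that can be rolled with static shapes, and output computed as a contraction -- can be illustrated with a minimal sketch. This is not the paper's implementation (see the repository above for that); the function names, shapes, and the scalar-input simplification are assumptions made for clarity. The key point it demonstrates is that both decode and prefill reduce to elementwise ops and a dot product under `jax.jit`/`jax.lax.scan`, with the recurrent state itself serving as the entire decode-time cache:

```python
# Illustrative sketch (assumed names/shapes, not the paper's code): a diagonal
# SSM whose carried state h is the whole O(1) decode-time cache under jax.jit.
import jax
import jax.numpy as jnp


@jax.jit
def decode_step(h, a, b, c, x):
    """One cached autoregressive step: h' = a*h + b*x, y = <c, h'>.

    h : (d_state,) carried state -- the entire decode cache
    a : (d_state,) diagonal state transition (elementwise, so it fuses in XLA)
    b, c : (d_state,) per-step input/output projections
    x : () scalar input channel (single channel for simplicity)
    """
    h = a * h + b * x      # diagonal recurrence: purely elementwise
    y = jnp.dot(c, h)      # einsum-style contraction for the output
    return h, y


def prefill(h0, a, bs, cs, xs):
    """Prefill is the same step rolled over the prompt with lax.scan.

    Static shapes mean XLA compiles this once; no host synchronisation
    is needed between steps.
    """
    def body(h, inp):
        b_t, c_t, x_t = inp
        return decode_step(h, a, b_t, c_t, x_t)

    return jax.lax.scan(body, h0, (bs, cs, xs))
```

Because the state update is diagonal (elementwise) rather than a dense matrix multiply, XLA can fuse the whole step into a few kernels on any backend, which is the property the abstract claims makes hand-written CUDA/Triton kernels optional.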