State-space model releases are typically coupled to fused CUDA and Triton kernels, inheriting a hard dependency on NVIDIA hardware. We show that Mamba-2's state space duality algorithm -- diagonal state structure, chunkable recurrence, and einsum-dominated compute with static control flow -- maps cleanly onto what XLA's fusion and tiling passes actually optimise, making custom kernels optional rather than required. We implement the full inference path (prefill, cached autoregressive decoding) as shaped standard primitives under XLA, without hand-written kernels, and realise the architecture's theoretical $O(1)$ state management as a compiled on-device cache requiring no host synchronisation during generation. The implementation runs unmodified on CPU, NVIDIA GPU, and Google Cloud TPU from a single JAX source. On TPU v6e across five model scales (130M--2.7B parameters), XLA-generated code reaches approximately 140 TFLOPS on single-stream prefill ($15%$ MFU) and up to $64%$ bandwidth utilisation on decode. Greedy decoding matches the PyTorch/CUDA reference token-for-token across 64 steps, with hidden-state agreement within float32 rounding tolerance. The pattern transfers to any SSM recurrence satisfying the same structural conditions, on any platform with a mature XLA backend. The implementation is publicly available at https://github.com/CosmoNaught/mamba2-jax and merged into the Bonsai JAX model library.


翻译:状态空间模型的发布通常与融合的CUDA和Triton内核绑定,从而对NVIDIA硬件形成了硬性依赖。我们证明,Mamba-2的状态空间对偶算法——对角状态结构、可分块的递归、以einsum为主且具有静态控制流的计算——能够清晰地映射到XLA的融合与分块优化过程所实际优化的目标上,这使得自定义内核成为可选而非必需。我们在XLA框架下,将完整的推理路径(预填充、带缓存的自动回归解码)实现为具有确定形状的标准原语,无需手写内核,并将该架构理论上的$O(1)$状态管理实现为一个编译后的设备端缓存,该缓存在生成过程中无需主机同步。该实现无需修改即可在CPU、NVIDIA GPU和Google Cloud TPU上运行,且源自单一的JAX代码库。在TPU v6e上,针对五种模型规模(1.3亿至27亿参数),XLA生成的代码在单流预填充上达到约140 TFLOPS(15% MFU),在解码阶段带宽利用率最高可达64%。贪婪解码在64个步骤中与PyTorch/CUDA参考实现实现了逐令牌匹配,隐藏状态的一致性在float32舍入误差容限内。该模式可迁移至任何满足相同结构条件的SSM递归模型,以及任何拥有成熟XLA后端的平台。该实现已在https://github.com/CosmoNaught/mamba2-jax 公开,并已并入Bonsai JAX模型库。

0
下载
关闭预览

相关内容

TransMLA:多头潜在注意力(MLA)即为所需
专知会员服务
23+阅读 · 2025年2月13日
【NeurIPS2023】基于语义对齐的潜空间翻译
专知会员服务
21+阅读 · 2023年11月2日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
1+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
1+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
VIP会员
最新内容
国外海军作战管理系统与作战训练系统
专知会员服务
0+阅读 · 今天4:16
美军条令《海军陆战队规划流程(2026版)》
专知会员服务
4+阅读 · 今天3:36
《压缩式分布式交互仿真标准》120页
专知会员服务
3+阅读 · 今天3:21
《电子战数据交换模型研究报告》
专知会员服务
4+阅读 · 今天3:13
《基于Transformer的异常舰船导航识别与跟踪》80页
《低数据领域军事目标检测模型研究》
专知会员服务
4+阅读 · 今天2:37
【CMU博士论文】物理世界的视觉感知与深度理解
伊朗战争停火期间美军关键弹药状况分析
专知会员服务
8+阅读 · 4月22日
电子战革命:塑造战场的十年突破(2015–2025)
相关VIP内容
TransMLA:多头潜在注意力(MLA)即为所需
专知会员服务
23+阅读 · 2025年2月13日
【NeurIPS2023】基于语义对齐的潜空间翻译
专知会员服务
21+阅读 · 2023年11月2日
相关基金
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
1+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
1+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
Top
微信扫码咨询专知VIP会员