Batch inference workloads for causal transformer models frequently process sequences that share common prefixes, such as system prompts, few-shot examples, or shared queries. Standard inference engines treat each sequence independently, redundantly recomputing identical MLP activations for every copy of the shared prefix. We introduce RadixMLP, a technique that exploits the position-wise nature of MLPs, LayerNorms, linear projections, and embeddings to eliminate this redundancy. RadixMLP dynamically maps batches to a prefix trie, gathering shared segments into a compressed representation for position-wise computation and scattering results back only at attention boundaries. RadixMLP is stateless and operates within a single forward pass. In end-to-end serving benchmarks on MS~MARCO v1.1 with Qwen3 models (0.6B to 8B parameters), RadixMLP achieves 1.44-1.59$\times$ speedups in realistic reranking workloads, with up to $5\times$ speedups on synthetic benchmarks with longer shared prefixes. Our code is available at https://github.com/michaelfeil/radix-mlp.
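The core idea, computing position-wise layers once for a shared prefix and scattering the result to every sequence in the batch, can be illustrated with a minimal sketch. This is not the RadixMLP implementation (which uses a prefix trie over token segments and operates inside a transformer forward pass); it is a simplified numpy toy, assuming a single prefix shared by the entire batch and a two-layer ReLU MLP:

```python
import numpy as np

def mlp(x, w1, w2):
    # Position-wise two-layer ReLU MLP: each row is processed independently,
    # so identical input rows always yield identical output rows.
    return np.maximum(x @ w1, 0.0) @ w2

def prefix_shared_mlp(batch, w1, w2):
    """batch: (B, T, D) hidden states. Run the MLP once on the longest
    prefix identical across all B sequences, then only on the suffixes."""
    B, T, D = batch.shape
    # same[t] is True when position t is identical in every sequence.
    same = np.all(batch == batch[0:1], axis=(0, 2))
    p = T if same.all() else int(np.argmin(same))  # shared-prefix length
    out = np.empty_like(batch)
    out[:, :p] = mlp(batch[0, :p], w1, w2)         # compute once, scatter to all
    if p < T:
        tail = batch[:, p:].reshape(-1, D)          # per-sequence suffixes
        out[:, p:] = mlp(tail, w1, w2).reshape(B, T - p, D)
    return out
```

Because the MLP is position-wise, the output matches a naive per-sequence forward pass exactly, while the prefix is computed once instead of B times; attention layers, which mix information across positions, are where results must be scattered back to full batch shape.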