State-of-the-art Learned Sparse Retrieval (LSR) models, such as Splade, typically employ a Language Modeling (LM) head to project latent hidden states into a lexically anchored logit matrix. This intermediate matrix is then transformed into a sparse lexical representation through element-wise operations (ReLU, Log1P) and max-pooling over the sequence dimension. Despite its effectiveness, the LM head is expensive because of the vocabulary size (|V|), which ranges from 30,000 to over 250,000 tokens in recent models. Materializing the full logit matrix creates a significant memory bottleneck that limits model scaling, and the resulting I/O overhead between operators further throttles throughput and runtime performance. In this paper, we propose Sparton, a fast, memory-efficient Triton kernel tailored for the LM head in LSR models. Sparton uses a fused approach that integrates the tiled matrix multiplication, ReLU, Log1P, and max-reduction into a single GPU kernel. By performing an early online reduction directly on raw logit tiles, Sparton avoids materializing the full logit matrix in memory. Our experiments demonstrate that the Sparton kernel, in isolation, achieves up to a 4.8x speedup and an order-of-magnitude reduction in peak memory usage compared to PyTorch baselines. Integrated into Splade (|V| ~ 30k), Sparton enables a 33% larger batch size and 14% faster training with no loss in effectiveness. On a multilingual backbone (|V| ~ 250k), these gains rise to a 26x larger batch size and 2.5x faster training.
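The early reduction on raw logits described above works because log1p(relu(x)) is non-decreasing in x, so the max over the sequence dimension can be taken before the activation is applied. The following is a minimal CPU sketch of that idea, not the Sparton kernel itself: it uses plain Python lists, omits the LM head bias, and ignores the backward pass. The function names `splade_head_naive` and `splade_head_tiled`, and the tile sizes, are hypothetical and chosen only for illustration.

```python
import math
import random


def splade_head_naive(hidden, weight):
    """Reference: materialize the full (seq_len x vocab) logit matrix, then pool."""
    logits = [[sum(h * w for h, w in zip(h_row, w_row)) for w_row in weight]
              for h_row in hidden]
    activated = [[math.log1p(max(x, 0.0)) for x in row] for row in logits]
    # Max-pool over the sequence dimension: one value per vocabulary entry.
    return [max(col) for col in zip(*activated)]


def splade_head_tiled(hidden, weight, seq_tile=2, vocab_tile=3):
    """Tiled variant: keep only a running max of raw logits per vocabulary entry."""
    vocab = len(weight)
    running_max = [-math.inf] * vocab
    for v0 in range(0, vocab, vocab_tile):
        for s0 in range(0, len(hidden), seq_tile):
            # Visit one (seq_tile x vocab_tile) block; each logit is reduced
            # into running_max immediately and never stored.
            for v in range(v0, min(v0 + vocab_tile, vocab)):
                for h_row in hidden[s0:s0 + seq_tile]:
                    logit = sum(h * w for h, w in zip(h_row, weight[v]))
                    running_max[v] = max(running_max[v], logit)
    # log1p(relu(x)) is non-decreasing, so applying it once after the max
    # gives the same result as applying it to every logit before pooling.
    return [math.log1p(max(m, 0.0)) for m in running_max]


# Small random example (sizes are arbitrary, assumed for illustration).
random.seed(0)
seq_len, dim, vocab = 5, 4, 7
hidden = [[random.uniform(-1, 1) for _ in range(dim)] for _ in range(seq_len)]
weight = [[random.uniform(-1, 1) for _ in range(dim)] for _ in range(vocab)]

naive = splade_head_naive(hidden, weight)
tiled = splade_head_tiled(hidden, weight)
assert all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(naive, tiled))
```

In this sketch the naive version holds seq_len × |V| logits at once, while the tiled version holds only the |V|-sized running maximum, which is where the memory saving comes from. A GPU kernel would additionally compute each tile with a blocked matrix multiplication rather than scalar dot products.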