Late-interaction retrieval (ColBERT, ColPali) scores a query against a document with the MaxSim operator: for each query token, take the maximum similarity over all document tokens, then sum these maxima over the query tokens. The standard implementation materializes the full query-token × document-token similarity tensor in GPU memory; for visual ColPali at 10K documents this tensor alone occupies 21 GB in FP16, yet it is created only to be reduced to one score per document and then discarded. It exhausts a 40 GB GPU and bounds the achievable batch size in both inference and training. We present Flash-MaxSim, an IO-aware fused GPU kernel that computes exactly the same scores without ever materializing the tensor, by streaming query and document tiles through on-chip SRAM and folding the row-maximum reduction into the same pass. We extend the IO-aware principle to the training backward pass, using an inverse-grid CSR construction that reuses the forward argmax for an atomic-free, destination-owned gradient reduction, and to INT8×INT8 quantization and variable-length (padding-free) scoring. Flash-MaxSim is up to 3.9x faster on an A100 (4.7x on an H100) than naive PyTorch at matched precision, uses up to 16x less inference memory and ~28x less training memory, unlocks corpus and batch sizes that exhaust PyTorch entirely, and preserves the exact ranking (100% top-20 agreement with an FP32 reference).
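To make the MaxSim definition and the streaming idea concrete, the following is a minimal CPU sketch in NumPy. It is not the Flash-MaxSim kernel: the function names `maxsim_naive` and `maxsim_tiled`, the `tile` parameter, and the toy shapes are assumptions chosen for illustration. The naive version builds the full (documents × query tokens × document tokens) similarity tensor, while the tiled version only ever holds one block of document tokens and keeps a running row maximum, which is the reduction a fused kernel would fold into its single pass.

```python
import numpy as np

def maxsim_naive(Q, D):
    """Q: (Nq, d) query token embeddings; D: (B, Nd, d) document token embeddings."""
    # Materializes the full (B, Nq, Nd) similarity tensor before reducing it.
    sim = np.einsum("qd,bnd->bqn", Q, D)
    # Max over document tokens, then sum over query tokens: one score per document.
    return sim.max(axis=2).sum(axis=1)

def maxsim_tiled(Q, D, tile=4):
    """Same scores, but only one (B, Nq, tile) block exists at any time."""
    B, Nd, _ = D.shape
    # Running maximum per (document, query token), updated tile by tile.
    running_max = np.full((B, Q.shape[0]), -np.inf)
    for start in range(0, Nd, tile):
        block = np.einsum("qd,bnd->bqn", Q, D[:, start:start + tile])
        np.maximum(running_max, block.max(axis=2), out=running_max)
    return running_max.sum(axis=1)

# Toy data (hypothetical sizes): 5 query tokens, 3 documents of 10 tokens, dim 8.
rng = np.random.default_rng(0)
Q = rng.standard_normal((5, 8))
D = rng.standard_normal((3, 10, 8))

scores_naive = maxsim_naive(Q, D)
scores_tiled = maxsim_tiled(Q, D, tile=4)
```

Because the maximum is associative, the tile size changes only the peak memory, not the result; the two functions agree up to floating-point rounding of the dot products.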