Maximum inner product search (MIPS) is a crucial subroutine in machine learning: given a query vector, identify the key vector in a database that aligns best with it. We propose amortized MIPS: a learning-based approach that trains neural networks to directly predict MIPS solutions, amortizing the computational cost of matching queries (drawn from a fixed distribution) to a fixed set of keys. Our key insight is that the MIPS value function, the maximal inner product between a query and the keys, is exactly the support function of the set of keys. Support functions are convex and positively 1-homogeneous, and their gradient with respect to the query is exactly the optimal key in the database. We approximate the support function using two complementary approaches: (1) we train an input-convex neural network (SupportNet) to model the support function directly, so that the optimal key can be recovered via (autodiff) gradient computation; and (2) we directly regress the optimal key from the query using a vector-valued network (KeyNet), bypassing gradient computation entirely at inference time. To learn a SupportNet, we combine score regression with gradient-matching losses, and propose homogenization wrappers that enforce the positive 1-homogeneity of a neural network, theoretically linking function values to gradients. To train a KeyNet, we introduce a score-consistency loss derived from Euler's theorem for homogeneous functions. Our experiments show that learned SupportNets and KeyNets achieve high match rates and open up new directions for compressing databases with a specific query distribution in mind.
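The geometric facts the abstract builds on can be verified numerically. The following is a minimal NumPy sketch (all variable names hypothetical, not from the paper): it computes the support function of a random key set, checks via finite differences that its gradient at a query is the optimal key, and checks positive 1-homogeneity together with Euler's identity h(q) = ⟨∇h(q), q⟩.

```python
import numpy as np

rng = np.random.default_rng(0)
K = rng.normal(size=(100, 8))  # hypothetical database of 100 keys in R^8
q = rng.normal(size=8)         # a query vector

# Support function of the key set: h(q) = max_k <q, k>,
# i.e. the MIPS value function for this database.
scores = K @ q
i_star = int(np.argmax(scores))  # index of the optimal key
h = scores[i_star]

# Where the argmax is unique, the gradient of h at q is the optimal key
# itself. Finite differences stand in for autodiff here:
eps = 1e-6
grad = np.array(
    [(np.max(K @ (q + eps * np.eye(8)[j])) - h) / eps for j in range(8)]
)
assert np.allclose(grad, K[i_star], atol=1e-4)

# Positive 1-homogeneity: h(t q) = t h(q) for t > 0.
assert np.isclose(np.max(K @ (2.5 * q)), 2.5 * h)

# Euler's identity for positively 1-homogeneous functions:
# h(q) = <grad h(q), q> — the relation behind the score-consistency loss.
assert np.isclose(grad @ q, h, atol=1e-3)
```

This is why a single trained model can serve both roles: a scalar network approximating h yields the optimal key through its gradient, while a vector-valued network predicting the key can be supervised on scores through the Euler identity.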