Cross Attention is a popular method for retrieving information from a set of context tokens for making predictions. At inference time, for each prediction, Cross Attention scans the full set of $\mathcal{O}(N)$ tokens. In practice, however, often only a small subset of tokens is required for good performance. Methods such as Perceiver IO are cheap at inference because they distill the information into a smaller set of $L < N$ latent tokens, on which cross attention is then applied, resulting in only $\mathcal{O}(L)$ complexity. However, as the number of input tokens and the amount of information to distill increase, the number of latent tokens needed also increases significantly. In this work, we propose Tree Cross Attention (TCA), a module based on Cross Attention that retrieves information from only a logarithmic $\mathcal{O}(\log(N))$ number of tokens when performing inference. TCA organizes the data in a tree structure and performs a tree search at inference time to retrieve the tokens relevant for prediction. Leveraging TCA, we introduce ReTreever, a flexible architecture for token-efficient inference. We show empirically that TCA performs comparably to Cross Attention across various classification and uncertainty regression tasks while being significantly more token-efficient. Furthermore, we compare ReTreever against Perceiver IO, showing significant gains while using the same number of tokens for inference.
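As a rough illustration of how a tree search can cut the retrieved set from $N$ tokens to $\mathcal{O}(\log(N))$, the sketch below builds a balanced binary tree over the context tokens, descends from the root to one leaf, and keeps the sibling skipped at each level together with the leaf reached. The abstract does not specify how the tree is built or how the search chooses a branch, so everything here is an assumption made for illustration: the mean-pooling aggregator, the greedy dot-product branching rule, the power-of-two $N$, and the function names `build_tree`, `retrieve`, and `cross_attention` are all hypothetical and are not the authors' method.

```python
import numpy as np


def build_tree(tokens):
    """Build a balanced binary tree bottom-up.

    levels[0] holds the N leaf tokens, levels[-1] holds the single root.
    Assumption: N is a power of two and a parent is the mean of its children.
    """
    levels = [tokens]
    while levels[-1].shape[0] > 1:
        cur = levels[-1]
        levels.append(0.5 * (cur[0::2] + cur[1::2]))
    return levels


def retrieve(levels, query):
    """Walk from the root to one leaf, collecting log2(N) + 1 nodes.

    Assumption: the branch is chosen greedily by dot product with the query.
    The skipped sibling at each level summarises the subtree not explored.
    """
    selected = []
    idx = 0  # index of the current node within its level (root is 0)
    for depth in range(len(levels) - 1, 0, -1):
        children = levels[depth - 1]
        left, right = 2 * idx, 2 * idx + 1
        scores = children[[left, right]] @ query
        go, skip = (left, right) if scores[0] >= scores[1] else (right, left)
        selected.append(children[skip])
        idx = go
    selected.append(levels[0][idx])  # the leaf reached by the search
    return np.stack(selected)


def cross_attention(query, context):
    """Single-query softmax attention with context used as keys and values."""
    scores = context @ query / np.sqrt(query.shape[0])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ context


rng = np.random.default_rng(0)
tokens = rng.normal(size=(16, 8))  # N = 16 context tokens of dimension 8
query = rng.normal(size=8)

levels = build_tree(tokens)
retrieved = retrieve(levels, query)            # 5 nodes instead of 16 tokens
prediction = cross_attention(query, retrieved)
```

With $N = 16$ the search retrieves $\log_2(16) + 1 = 5$ nodes, and these nodes jointly summarise every leaf: the skipped siblings cover the unexplored subtrees, and the final leaf covers the explored path.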