In the classical transformer attention scheme, we are given three $n \times d$ size matrices $Q, K, V$ (the query, key, and value tokens), and the goal is to compute a new $n \times d$ size matrix $D^{-1} \exp(QK^\top) V$ where $D = \mathrm{diag}( \exp(QK^\top) {\bf 1}_n )$. In this work, we study a generalization of attention which captures triple-wise correlations. This generalization is able to solve problems about detecting triple-wise connections that were shown to be impossible for transformers. The potential downside of this generalization is that it appears as though computations are even more difficult, since the straightforward algorithm requires cubic time in $n$. However, we show that in the bounded-entry setting (which arises in practice, and which is well-studied in both theory and practice), there is actually a near-linear time algorithm. More precisely, we show that bounded entries are both necessary and sufficient for quickly performing generalized computations: $\bullet$ On the positive side, if all entries of the input matrices are bounded above by $o(\sqrt[3]{\log n})$ then we show how to approximate the ``tensor-type'' attention matrix in $n^{1+o(1)}$ time. $\bullet$ On the negative side, we show that if the entries of the input matrices may be as large as $\Omega(\sqrt[3]{\log n})$, then there is no algorithm that runs faster than $n^{3-o(1)}$ (assuming the Strong Exponential Time Hypothesis from fine-grained complexity theory). We also show that our construction, algorithms, and lower bounds naturally generalize to higher-order tensors and correlations. Interestingly, the higher the order of the tensors, the lower the bound on the entries needs to be for an efficient algorithm. Our results thus yield a natural tradeoff between the boundedness of the entries, and order of the tensor one may use for more expressive, efficient attention computation.
翻译:在经典Transformer注意力机制中,给定三个$n \times d$维矩阵$Q, K, V$(查询、键和值标记),目标为计算新的$n \times d$维矩阵$D^{-1} \exp(QK^\top) V$,其中$D = \mathrm{diag}( \exp(QK^\top) {\bf 1}_n)$。本研究考虑一种捕获三重相关性的广义注意力机制。该泛化能解决检测三重连接的问题,而此类问题已被证明对Transformer不可解。其潜在缺陷在于计算复杂度显著增加——直接算法的复杂度为$n$的三次方。然而我们证明,在有界条目设定下(实践中常见且被理论与实务广泛研究),存在近线性时间算法。具体而言,我们证明有界性是快速执行广义计算的充分必要条件:
$\bullet$ 正向结论:若输入矩阵所有条目均以$o(\sqrt[3]{\log n})$为上界,则可在$n^{1+o(1)}$时间内近似"张量型"注意力矩阵。
$\bullet$ 反向结论:若输入矩阵条目可达$\Omega(\sqrt[3]{\log n})$,则不存在快于$n^{3-o(1)}$的算法(基于细粒度复杂度理论中的强指数时间假设)。
我们还证明,我们的构造、算法与下界可自然推广至高阶张量与相关性。值得关注的是,张量阶数越高,高效算法对条目上界的要求越严格。因此,我们的结果揭示了条目有界性与可表达、高效注意力计算所需张量阶数之间的自然权衡。