In modern machine learning, inner product attention computation is a fundamental task for training large language models such as Transformer, GPT-1, BERT, GPT-2, GPT-3 and ChatGPT. Formally, in this problem, one is given as input three matrices $Q, K, V \in [-B,B]^{n \times d}$, and the goal is to construct the matrix $\mathrm{Att}(Q,K,V) := \mathrm{diag}(A {\bf 1}_n)^{-1} A V \in \mathbb{R}^{n \times d}$, where $A = \exp(QK^\top/d)$ is the "attention matrix" and $\exp$ is applied entry-wise. Straightforward methods for this problem explicitly compute the $n \times n$ attention matrix $A$, and hence require time $\Omega(n^2)$ even when $d = n^{o(1)}$ is small. In this paper, we investigate whether faster algorithms are possible by implicitly making use of the matrix $A$. We present two results, showing that there is a sharp transition at $B = \Theta(\sqrt{\log n})$:

$\bullet$ If $d = O(\log n)$ and $B = o(\sqrt{\log n})$, there is an $n^{1+o(1)}$ time algorithm to approximate $\mathrm{Att}(Q,K,V)$ up to $1/\mathrm{poly}(n)$ additive error.

$\bullet$ If $d = O(\log n)$ and $B = \Theta (\sqrt{\log n})$, assuming the Strong Exponential Time Hypothesis from fine-grained complexity theory, it is impossible to approximate $\mathrm{Att}(Q,K,V)$ up to $1/\mathrm{poly}(n)$ additive error in truly subquadratic time $n^{2 - \Omega(1)}$.

This gives a theoretical explanation for the phenomenon observed in practice that attention computation is much more efficient when the input matrices have smaller entries.
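To make the definition of $\mathrm{Att}(Q,K,V)$ and the contrast between explicit and implicit use of $A$ concrete, here is a small NumPy sketch. `attention_exact` follows the formula directly and materializes the full $n \times n$ matrix. `attention_lowrank`, `taylor_features` and the `degree` parameter are hypothetical names for an illustrative technique swapped in here, not claimed to be the authors' algorithm. The technique replaces $\exp$ by a truncated Taylor series, which factors $A \approx U_1 U_2^\top$ with $U_1, U_2 \in \mathbb{R}^{n \times r}$. The products with $V$ and ${\bf 1}_n$ can then be formed in $O(nrd)$ time without ever building an $n \times n$ matrix. The assumption doing the work is that entries are small: since $|(QK^\top)_{ij}/d| \le B^2$, a low degree already gives a tiny error when $B$ is small. Larger $B$ would force a higher degree, and the feature count $r$ grows quickly with both the degree and $d$.

```python
import math

import numpy as np


def attention_exact(Q, K, V):
    """Att(Q, K, V) = diag(A 1_n)^{-1} A V with A = exp(Q K^T / d).

    Materializes the full n x n attention matrix A, so it takes
    Theta(n^2 d) time and Theta(n^2) memory.
    """
    d = Q.shape[1]
    A = np.exp(Q @ K.T / d)              # n x n, exp applied entry-wise
    row_sums = A @ np.ones(A.shape[1])   # the vector A 1_n
    return (A @ V) / row_sums[:, None]   # diag(A 1_n)^{-1} A V


def taylor_features(X, degree):
    """Hypothetical helper: map each row x of X to phi(x) such that
    <phi(q), phi(k)> = sum_{j=0}^{degree} (<q, k> / d)^j / j!,
    i.e. the degree-`degree` Taylor polynomial of exp(<q, k> / d).

    Uses full tensor powers, so phi(x) has sum_j d^j coordinates; a leaner
    version would keep only the C(d + degree, degree) distinct monomials.
    """
    n, d = X.shape
    feats = [np.ones((n, 1))]            # j = 0 term of the series
    power = np.ones((n, 1))
    for j in range(1, degree + 1):
        # Row-wise Kronecker product: row i becomes the j-fold tensor power of x_i.
        power = (power[:, :, None] * X[:, None, :]).reshape(n, -1)
        # <x^(tensor j), y^(tensor j)> = <x, y>^j, so this scaling gives (<x, y>/d)^j / j!.
        feats.append(power / math.sqrt(math.factorial(j) * d**j))
    return np.concatenate(feats, axis=1)


def attention_lowrank(Q, K, V, degree=6):
    """Hypothetical sketch: approximate Att(Q, K, V) via A ~ U1 U2^T without
    forming any n x n matrix; cost is O(n r d) for r feature columns.
    An even degree keeps every approximated entry of A strictly positive.
    """
    U1 = taylor_features(Q, degree)      # n x r
    U2 = taylor_features(K, degree)      # n x r
    numer = U1 @ (U2.T @ V)              # approximates A V
    denom = U1 @ U2.sum(axis=0)          # approximates A 1_n
    return numer / denom[:, None]


rng = np.random.default_rng(0)
n, d, B = 300, 4, 0.5                    # small entries: |<q, k>| / d <= B^2
Q, K, V = (rng.uniform(-B, B, size=(n, d)) for _ in range(3))
exact = attention_exact(Q, K, V)
approx = attention_lowrank(Q, K, V)
err = np.max(np.abs(exact - approx))
print(f"max abs error: {err:.2e}")
```

With $B = 0.5$ every exponent lies in $[-0.25, 0.25]$, where the degree-6 series is very accurate, so the two outputs should agree closely. Raising `B` (say to 3) while keeping `degree=6` should make the error grow substantially. This loosely mirrors the intuition behind the bounded-entries condition, though it proves nothing about the lower bound.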