Despite their power, Transformers struggle with long sequences due to the quadratic complexity of self-attention. To address this limitation, methods like $k$-Nearest-Neighbor ($k$NN) attention have been introduced [Roy, Saffar, Vaswani, Grangier, 2021], allowing each token to attend only to its $k$ closest tokens. While $k$NN attention has shown empirical success in making Transformers more efficient, its exact approximation guarantees have not been theoretically analyzed. In this work, we establish a theoretical framework for $k$NN attention, reformulating self-attention as expectations over softmax distributions and leveraging lazy Gumbel sampling [Mussmann, Levy, Ermon, 2017] with $k$NN indices for efficient approximation. Building on this framework, we also propose novel sub-quadratic algorithms that approximate self-attention gradients via efficient sampling techniques, such as Markov Chain-based estimation. Finally, we demonstrate the practical effectiveness of these algorithms through empirical experiments, showcasing their benefits in both training and inference.
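To make the mechanism concrete, the following NumPy sketch contrasts exact softmax attention with a $k$NN restriction in which each query attends only to its $k$ highest-scoring keys. This is a minimal illustration of the idea, not the paper's algorithm: the top-$k$ search here is brute force (and thus still quadratic), whereas a real implementation would use a $k$NN index, and no Gumbel sampling is involved. All function names are our own.

```python
import numpy as np

def exact_attention(Q, K, V):
    # Standard softmax attention: O(n^2) in sequence length n.
    scores = Q @ K.T / np.sqrt(Q.shape[1])
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ V

def knn_attention(Q, K, V, k):
    # Each query attends only to its k highest-scoring keys.
    # Top-k is found by brute force here; a kNN index over the keys
    # would avoid the O(n^2) scan in practice.
    scores = Q @ K.T / np.sqrt(Q.shape[1])
    out = np.zeros_like(V, dtype=float)
    for i in range(scores.shape[0]):
        topk = np.argpartition(scores[i], -k)[-k:]
        s = scores[i, topk]
        w = np.exp(s - s.max())
        w /= w.sum()                     # softmax restricted to the top-k keys
        out[i] = w @ V[topk]
    return out

rng = np.random.default_rng(0)
n, d = 64, 16
Q, K, V = rng.standard_normal((3, n, d))
full = exact_attention(Q, K, V)
approx = knn_attention(Q, K, V, k=16)
err = np.abs(full - approx).max()        # shrinks as k grows; zero at k = n
```

Because the softmax places most of its mass on the largest scores, the truncated distribution is typically close to the full one, which is the intuition the paper's approximation guarantees make precise.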