Standard inference and training with Transformer-based architectures scale quadratically with input sequence length. This cost is prohibitive for many applications, especially web-page translation and query answering. Consequently, several approaches have recently been developed to speed up attention computation by enforcing particular attention structures, such as sparsity, low rank, or kernel-based approximation. In this work, we view attention computation as nearest neighbor retrieval, and use decision-tree-based hierarchical navigation to reduce the retrieval cost per query token from linear in sequence length to nearly logarithmic. Based on this hierarchical navigation, we design Treeformer, which can use one of two efficient attention layers: TF-Attention and TC-Attention. TF-Attention computes attention in a fine-grained style, while TC-Attention is a coarse attention layer that also ensures the gradients are "dense". To optimize such challenging discrete layers, we propose a two-level bootstrapped training method. Through extensive experiments on standard NLP benchmarks, especially those with long sequences, we demonstrate that our Treeformer architecture can be almost as accurate as the baseline Transformer while using 30x fewer FLOPs in the attention layer. Compared to Linformer, accuracy can be as much as 12% higher at a similar FLOP count in the attention layer.
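To make the retrieval view of attention concrete, the following is a minimal NumPy sketch of tree-routed attention. It is an illustration under stated assumptions, not the Treeformer implementation: the abstract does not specify the form of the tree, so the sketch assumes a complete binary tree with a linear (hyperplane) test at each internal node, uses random rather than learned splits, and restricts each query to the keys that reach the same leaf. The names `route_to_leaves` and `leaf_attention` are hypothetical.

```python
import numpy as np


def route_to_leaves(x, hyperplanes, biases):
    """Route each row of x down a complete binary tree and return leaf ids.

    Assumption (not from the paper): internal nodes are stored in heap order,
    so the children of node i are 2*i + 1 (left) and 2*i + 2 (right), and each
    node applies the linear test x . w + b > 0.
    """
    num_internal = hyperplanes.shape[0]            # 2**height - 1 nodes
    height = (num_internal + 1).bit_length() - 1   # exact log2 for powers of 2
    node = np.zeros(x.shape[0], dtype=np.int64)    # every row starts at the root
    for _ in range(height):
        w = hyperplanes[node]                      # (n, d): one split per row
        b = biases[node]
        go_right = (np.einsum("nd,nd->n", x, w) + b > 0).astype(np.int64)
        node = 2 * node + 1 + go_right
    return node - num_internal                     # leaf ids in [0, 2**height)


def leaf_attention(q, k, v, hyperplanes, biases):
    """Softmax attention where each query only sees keys in its own leaf."""
    q_leaf = route_to_leaves(q, hyperplanes, biases)
    k_leaf = route_to_leaves(k, hyperplanes, biases)
    out = np.zeros((q.shape[0], v.shape[1]))
    for leaf in np.unique(q_leaf):
        qi = np.where(q_leaf == leaf)[0]
        ki = np.where(k_leaf == leaf)[0]
        if ki.size == 0:
            continue                               # no keys here: row stays zero
        scores = q[qi] @ k[ki].T / np.sqrt(q.shape[1])
        scores -= scores.max(axis=1, keepdims=True)  # numerical stability
        weights = np.exp(scores)
        weights /= weights.sum(axis=1, keepdims=True)
        out[qi] = weights @ v[ki]
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, d, height = 1024, 16, 5
    x = rng.standard_normal((n, d))
    W = rng.standard_normal((2**height - 1, d))    # random, not learned, splits
    b = np.zeros(2**height - 1)
    y = leaf_attention(x, x, rng.standard_normal((n, d)), W, b)
    print(y.shape)
```

Routing a token costs one dot product per tree level, so the navigation cost grows with the tree height rather than with the sequence length. The score computation inside a leaf then scales with the number of keys in that leaf, which is where the savings over full attention would come from if the leaves stay small. With an empty tree (zero internal nodes) every token lands in the same leaf and the sketch reduces to ordinary softmax attention.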