In recent years, the transformer architecture has driven breakthroughs on tasks that require modeling pairwise relationships between sequential elements, such as natural language understanding. However, long sequences remain problematic because the attention operation has quadratic complexity in the sequence length. Previous research has aimed to lower this complexity by sparsifying or linearly approximating the attention matrix. Yet these approaches cannot straightforwardly distill knowledge from a teacher's attention matrix and often require complete retraining from scratch. Furthermore, previous sparse and linear approaches lose interpretability when they cannot produce full attention matrices. To address these challenges, we propose SEA: Sparse linear attention with an Estimated Attention mask. SEA first estimates the attention matrix with linear complexity via kernel-based linear attention, then builds a sparse attention matrix through top-k selection and uses it to perform a sparse attention operation. On language modeling (Wikitext2), previous linear and sparse attention methods show roughly two-fold worse perplexity than the quadratic OPT-1.3B baseline, whereas SEA achieves better perplexity than OPT-1.3B while using roughly half of its memory and still providing an interpretable attention matrix. We believe our work will have a large practical impact, as it opens the possibility of running large transformers on resource-limited devices with less memory.
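The estimate, top-k, sparse-attend pipeline described above can be illustrated with a small plain-Python toy. This is a minimal sketch under stated assumptions, not the SEA implementation: the feature map φ(x) = elu(x) + 1 is an assumed choice of kernel, the helper names (`estimate_scores`, `topk_mask`, `sparse_attention`) are hypothetical, and for readability the kernel estimate is materialized for every query–key pair, so the toy itself is quadratic. It only shows the data flow, not how SEA keeps the estimation step linear or how it distills from a teacher.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def elu_plus_one(x: float) -> float:
    # Assumed positive feature map phi(x) = elu(x) + 1, a common choice in
    # kernel-based linear attention; keeps estimated scores non-negative.
    return x + 1.0 if x > 0 else math.exp(x)


def phi(vec: List[float]) -> List[float]:
    return [elu_plus_one(x) for x in vec]


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def estimate_scores(Q: Matrix, K: Matrix) -> Matrix:
    # Kernel estimate A_hat[i][j] = phi(q_i) . phi(k_j).
    # NOTE: this toy scores every pair (quadratic); it is NOT a
    # linear-complexity estimator, only an illustration of the idea.
    fq = [phi(q) for q in Q]
    fk = [phi(k) for k in K]
    return [[dot(a, b) for b in fk] for a in fq]


def topk_mask(scores: Matrix, k: int) -> List[List[int]]:
    # For each query row, keep the indices of the k largest estimated scores.
    return [sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
            for row in scores]


def sparse_attention(Q: Matrix, K: Matrix, V: Matrix,
                     k: int) -> Tuple[Matrix, Matrix]:
    d = len(Q[0])
    mask = topk_mask(estimate_scores(Q, K), k)
    out: Matrix = []
    attn: Matrix = []  # full-size attention map, zero outside the mask
    for i, q in enumerate(Q):
        idx = mask[i]
        # Exact scaled dot-product logits, computed only on the k kept keys.
        logits = [dot(q, K[j]) / math.sqrt(d) for j in idx]
        m = max(logits)  # subtract the max for a numerically stable softmax
        exps = [math.exp(l - m) for l in logits]
        z = sum(exps)
        weights = [e / z for e in exps]
        row = [0.0] * len(K)
        for j, w in zip(idx, weights):
            row[j] = w
        attn.append(row)
        out.append([sum(w * V[j][c] for j, w in zip(idx, weights))
                    for c in range(len(V[0]))])
    return out, attn


# Tiny usage example: 3 queries, 3 keys, keep the top-2 keys per query.
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
out, attn = sparse_attention(Q, K, V, k=2)
print(attn)
```

Because `attn` keeps the full n × n shape, with zeros outside the top-k mask, it can still be inspected like a dense attention map, which is the interpretability property the abstract emphasizes. Setting `k` equal to the number of keys recovers ordinary dense softmax attention.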