We introduce the Block Transformer, which adopts hierarchical global-to-local modeling in autoregressive transformers to mitigate the inference bottlenecks associated with self-attention. Self-attention requires the key-value (KV) cache of the entire preceding sequence to be fetched from memory at every decoding step to access context information, leading to two primary bottlenecks during batch inference. First, there is a significant delay in obtaining the first token, as the entire prompt must first be processed to prefill the KV cache. Second, computation of subsequent tokens is bottlenecked by the high memory I/O demand of fetching the entire KV cache, which grows linearly with sequence length, incurring quadratic memory reads overall. We design the Block Transformer to strategically mitigate these costs by incorporating coarseness and locality into an integrated global-to-local architecture. At the lower layers, we aggregate tokens into fixed-size blocks and apply attention across the entire sequence at coarse granularity, capturing the global context while minimizing KV cache overhead. At the upper layers, we apply attention within each block to decode individual tokens, modeling fine-grained details with a lightweight local KV cache. We pretrain vanilla and Block Transformers from scratch and demonstrate that Block Transformers achieve 10--20x higher inference throughput than vanilla transformers with equivalent perplexity and zero-shot task performance. Code is available at https://github.com/itsnamgyu/block-transformer.
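The global-to-local structure described above can be illustrated with a minimal numpy sketch. This is a hypothetical simplification, not the paper's implementation: projection weights are omitted (queries, keys, and values are the raw embeddings), block aggregation is mean pooling, and the global context is prepended to each block as a single context vector for the local decoder. The function and variable names are our own for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(x):
    # Single-head attention with q = k = v = x (illustrative only).
    T, d = x.shape
    scores = x @ x.T / np.sqrt(d)
    scores[np.triu(np.ones((T, T), dtype=bool), k=1)] = -np.inf  # causal mask
    return softmax(scores) @ x

def block_transformer_sketch(tokens, block_size):
    # tokens: (T, d) token embeddings, with T divisible by block_size.
    T, d = tokens.shape
    n_blocks = T // block_size

    # Lower layers: aggregate each block of tokens into one coarse embedding,
    # then apply causal attention over the n_blocks coarse units. The global
    # KV cache holds n_blocks entries instead of T, shrinking memory I/O.
    blocks = tokens.reshape(n_blocks, block_size, d).mean(axis=1)
    context = causal_attention(blocks)

    # Upper layers: attend within each block to decode individual tokens,
    # conditioned on that block's global context vector (prepended as an
    # extra position). The local KV cache never exceeds block_size + 1.
    out = np.empty_like(tokens)
    for b in range(n_blocks):
        local = np.concatenate(
            [context[b : b + 1], tokens[b * block_size : (b + 1) * block_size]]
        )
        out[b * block_size : (b + 1) * block_size] = causal_attention(local)[1:]
    return out

x = np.random.default_rng(0).normal(size=(16, 8))
y = block_transformer_sketch(x, block_size=4)
```

Even in this toy form, the cost structure is visible: global attention scales with the number of blocks rather than the number of tokens, and each local attention is bounded by the block size, which is the source of the KV-cache savings claimed above.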