Autoregressive large language models (LLMs) have made remarkable progress in various natural language generation tasks. However, they incur high computation cost and latency because they generate text autoregressively, one token at a time. To address this issue, several approaches have been proposed to reduce computational cost using early-exit strategies. These strategies enable faster text generation by not applying the full computation graph to every token. While existing token-level early exit methods show promising results for online inference, they cannot be readily applied to batch inference and Key-Value (KV) caching, because they have to wait until the last token in a batch exits before they can stop computing. This severely limits the practical application of such techniques. In this paper, we propose a simple and effective token-level early exit method, SkipDecode, designed to work seamlessly with batch inference and KV caching. It overcomes prior constraints by setting up a single exit point for every token in a batch at each sequence position. It also guarantees a monotonic decrease in exit points, thereby eliminating the need to recompute KV caches for preceding tokens. Rather than terminating computation prematurely as in prior works, our approach bypasses lower to middle layers and devotes most of the computational resources to upper layers, allowing later tokens to benefit from the compute spent on earlier tokens. Our experimental results show that SkipDecode can obtain 2x to 5x inference speedups with negligible regression across a variety of tasks. These results are obtained with OPT models of 1.3 billion and 6.7 billion parameters, while remaining directly compatible with batching and KV caching optimization techniques.
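The scheduling idea described in the abstract (one exit point per sequence position shared by the whole batch, exit points that never increase, and skipping lower layers rather than stopping early) can be illustrated with a minimal Python sketch. The linear decay rule, the function names `layer_budget` and `active_layers`, and all parameter names are assumptions made for illustration; the abstract does not specify how the per-position budget is chosen.

```python
def layer_budget(position, max_len, max_layers, min_layers):
    """Number of layers to run at a given sequence position.

    Assumption: the budget decays linearly from max_layers at position 0
    to min_layers at position max_len - 1. The budget depends only on the
    position, so every sequence in a batch shares it.
    """
    if max_len <= 1:
        return max_layers
    # Clamp so positions beyond max_len - 1 keep the minimum budget.
    frac = min(position, max_len - 1) / (max_len - 1)
    return max_layers - round(frac * (max_layers - min_layers))


def active_layers(position, num_layers, max_len, max_layers, min_layers):
    """Indices of the layers executed at this position.

    The budget is spent on the top of the stack: lower layers are skipped
    and the uppermost `budget` layers are run.
    """
    budget = layer_budget(position, max_len, max_layers, min_layers)
    return range(num_layers - budget, num_layers)


def kv_cache_is_reusable(num_layers, max_len, max_layers, min_layers):
    """Check that each position only runs layers that every earlier
    position also ran, so cached keys and values exist for all
    preceding tokens and nothing has to be recomputed."""
    previous = set(active_layers(0, num_layers, max_len, max_layers, min_layers))
    for pos in range(1, max_len):
        current = set(active_layers(pos, num_layers, max_len, max_layers, min_layers))
        if not current <= previous:
            return False
        previous = current
    return True
```

For example, with a hypothetical 24-layer model, `max_len=10`, `max_layers=20`, and `min_layers=8`, position 0 runs layers 4 through 23 and position 9 runs layers 16 through 23. Because the budget never increases, the set of layers run at any position is contained in the set run at every earlier position, which is the property that keeps the KV cache valid.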