Auto-regressive Large Language Models (LLMs) demonstrate remarkable performance across different domains such as vision and language processing. However, due to sequential processing through a stack of transformer layers, autoregressive decoding faces significant computation/latency challenges, particularly in resource-constrained environments like mobile and edge devices. Existing approaches in literature that aim to improve latency via skipping layers have two distinct flavors - 1) Early exit, and 2) Input-agnostic heuristics where tokens exit at pre-determined layers irrespective of input sequence. Both the above strategies have limitations - the former cannot be applied to handle KV Caching necessary for speed-ups in modern framework and the latter does not capture the variation in layer importance across tasks or more generally, across input sequences. To address both limitations, we propose FiRST, an algorithm that reduces inference latency by using layer-specific routers to select a subset of transformer layers adaptively for each input sequence - the prompt (during the prefill stage) decides which layers will be skipped during decoding. FiRST preserves compatibility with KV caching enabling faster inference while being quality-aware. FiRST is model-agnostic and can be easily enabled on any pre-trained LLM. Our approach reveals that input adaptivity is critical - indeed, different task-specific middle layers play a crucial role in evolving hidden representations depending on tasks. Extensive experiments show that FiRST significantly reduces latency while outperforming other layer selection strategies in quality metics. It retains competitive performance to base model (without layer skipping) and in some cases, even improves upon it. FiRST is thus a promising and efficient solution for LLM deployment in low-resource environments.
翻译:自回归大语言模型(LLM)在视觉与语言处理等多个领域展现出卓越性能。然而,由于需通过堆叠的Transformer层进行序列化处理,自回归解码面临显著的计算/延迟挑战,在移动设备与边缘设备等资源受限环境中尤为突出。现有文献中旨在通过跳过层来改善延迟的方法主要分为两类:1)早期退出机制;2)输入无关启发式方法,即令词元在预设层退出而忽略输入序列。上述两种策略均存在局限——前者无法兼容现代加速框架所必需的KV缓存机制,后者则未能捕捉不同任务间或更广义上不同输入序列间层重要性的动态变化。为同时解决这两类问题,本文提出FiRST算法,该算法通过层特异性路由器为每个输入序列自适应选择Transformer层的子集以降低推理延迟——预填充阶段的提示词将决定解码过程中跳过的层。FiRST保持与KV缓存的兼容性,在实现加速推理的同时具备质量感知能力。本方法具有模型无关性,可便捷应用于任何预训练LLM。我们的研究表明输入自适应性至关重要:事实上,针对不同任务,特定的中间层在隐表示演化过程中发挥着关键作用。大量实验表明,FiRST在显著降低延迟的同时,在质量指标上优于其他层选择策略。其性能与基准模型(无跳层)保持竞争力,在某些情况下甚至有所提升。因此,FiRST为低资源环境下的LLM部署提供了高效且前景广阔的解决方案。