Inference in Large Language Models (LLMs) is often slow because auto-regressive decoding lacks parallelism: each step produces a single token, so most operations are bound by the memory bandwidth of the accelerator rather than by its compute. Methods such as speculative decoding have been proposed to address this, but they are hard to deploy because they require acquiring and maintaining a separate draft model. In this paper, we present Medusa, an efficient method that accelerates LLM inference by adding extra decoding heads that predict multiple subsequent tokens in parallel. Using a tree-based attention mechanism, Medusa constructs multiple candidate continuations and verifies them simultaneously in each decoding step. Because the candidates are processed in parallel, Medusa adds only minimal overhead to single-step latency while substantially reducing the number of decoding steps required. We present two fine-tuning procedures for different use cases. In Medusa-1, the Medusa heads are fine-tuned on top of a frozen backbone LLM, enabling lossless inference acceleration. In Medusa-2, the heads are fine-tuned jointly with the backbone LLM, which improves the heads' prediction accuracy and yields a higher speedup but requires a special training recipe that preserves the backbone model's capabilities. We also propose several extensions that improve or broaden Medusa's utility, including self-distillation for situations where no training data is available and a typical acceptance scheme that boosts the acceptance rate while maintaining generation quality. We evaluate Medusa on models of various sizes trained with different procedures. Our experiments show that Medusa-1 achieves over 2.2× speedup without compromising generation quality, while Medusa-2 further improves the speedup to 2.3–3.6×.
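The tree-based verification step can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the paper's implementation. It assumes that each Medusa head proposes a short list of tokens for its position, that the candidate set is the full Cartesian product of those lists (a practical system would likely prune this tree), that shared prefixes are merged into single tree nodes, and that each node attends only to itself and its ancestors. All names (`build_tree`, `tree_attention_mask`, `longest_accepted`) and the toy token ids are hypothetical. The backbone's greedy next-token prediction at each node, which in practice would come from one forward pass over the whole tree using this mask, is passed in as a plain dictionary.

```python
from itertools import product
from typing import Dict, List, Sequence, Tuple

Path = Tuple[int, ...]  # a tree node, identified by its token path from the root


def build_tree(head_topk: Sequence[Sequence[int]]) -> List[Path]:
    """Enumerate every distinct prefix of every candidate continuation.

    head_topk[k] holds the tokens proposed for the (k+1)-th future position
    (hypothetical toy input). Shared prefixes become a single node.
    """
    nodes: List[Path] = []
    seen = set()
    for candidate in product(*head_topk):
        for depth in range(1, len(candidate) + 1):
            prefix = tuple(candidate[:depth])
            if prefix not in seen:
                seen.add(prefix)
                nodes.append(prefix)
    nodes.sort(key=len)  # stable sort: every ancestor precedes its descendants
    return nodes


def tree_attention_mask(nodes: List[Path]) -> List[List[int]]:
    """mask[i][j] == 1 iff node j is node i itself or one of its ancestors."""
    return [
        [1 if ni[: len(nj)] == nj else 0 for nj in nodes]
        for ni in nodes
    ]


def longest_accepted(nodes: List[Path], greedy_next: Dict[Path, int]) -> Path:
    """Longest node whose every token matches the backbone's greedy choice.

    greedy_next[p] is the backbone's argmax token after path p; the empty
    tuple () stands for the current context before any candidate token.
    """
    best: Path = ()
    for node in nodes:
        if all(greedy_next[node[:d]] == node[d] for d in range(len(node))):
            if len(node) > len(best):
                best = node
    return best


if __name__ == "__main__":
    nodes = build_tree([[5, 6], [7, 8]])
    mask = tree_attention_mask(nodes)
    # Toy backbone: greedy choice is 5 after the context, then 8 after (5,).
    greedy = {(): 5, (5,): 8, (6,): 3}
    print(longest_accepted(nodes, greedy))  # (5, 8)
```

Merging shared prefixes is what lets candidates share computation: the four two-token candidates in the toy example occupy only 6 tree nodes instead of 8 separate token positions, and the mask keeps each branch from seeing its siblings.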
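The abstract mentions a typical acceptance scheme without stating the rule. A minimal sketch follows, assuming an entropy-dependent threshold: a candidate token is accepted when the backbone assigns it probability greater than min(ε, δ·exp(-H)), where H is the entropy of the backbone's next-token distribution at that position. The names `typical_accept` and `accepted_length`, and the default values of `epsilon` and `delta`, are illustrative assumptions. Because exp(-H) shrinks as H grows, the rule is more permissive exactly where the backbone is itself uncertain.

```python
import math
from typing import Sequence


def entropy(probs: Sequence[float]) -> float:
    """Shannon entropy (in nats) of a discrete distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def typical_accept(probs: Sequence[float], token: int,
                   epsilon: float = 0.1, delta: float = 0.3) -> bool:
    """Assumed rule: accept if p(token) > min(epsilon, delta * exp(-H(probs))).

    epsilon caps the threshold; delta scales the entropy-dependent part.
    Both default values are illustrative, not taken from the paper.
    """
    threshold = min(epsilon, delta * math.exp(-entropy(probs)))
    return probs[token] > threshold


def accepted_length(step_probs: Sequence[Sequence[float]],
                    candidate: Sequence[int],
                    epsilon: float = 0.1, delta: float = 0.3) -> int:
    """Count leading candidate tokens accepted before the first rejection.

    step_probs[i] is the backbone's next-token distribution at the position
    where candidate[i] was proposed (hypothetical toy input).
    """
    n = 0
    for probs, token in zip(step_probs, candidate):
        if not typical_accept(probs, token, epsilon, delta):
            break  # a rejected token invalidates everything after it
        n += 1
    return n


if __name__ == "__main__":
    flat = [0.25, 0.25, 0.25, 0.25]  # high entropy: threshold drops to about 0.075
    peaked = [0.9, 0.05, 0.05]       # low entropy: threshold stays at epsilon
    print(accepted_length([flat, peaked], [2, 1]))  # 1
    print(accepted_length([flat, peaked], [3, 0]))  # 2
```

The two toy distributions show the intended trade-off: under the flat distribution any token clears the lowered threshold, while under the peaked one a 5%-probability token is rejected. In a tree of candidates, one would apply this check along each path and keep the path with the longest accepted prefix.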