Text generation with Large Language Models (LLMs) is known to be memory-bound due to the combination of their autoregressive nature, huge parameter counts, and limited memory bandwidth, which often results in low token generation rates. Speculative decoding has been proposed as a solution for accelerating LLM inference. However, because draft models are often unavailable in modern open-source LLM families (e.g., for Llama 2 7B), a high-quality draft model must be trained before inference can be accelerated via speculative decoding. In this paper, we propose a simple draft model training framework for direct alignment to chat-capable target models. With the proposed framework, we train Llama 2 Chat Drafter 115M, a draft model for Llama 2 Chat 7B or larger that is only 1.64\% the size of the 7B target model. Our training framework consists only of pretraining, distillation dataset generation, and fine-tuning with knowledge distillation, with no additional alignment procedure. For the fine-tuning step, we use instruction-response pairs generated by the target model, so that distillation takes place on a plausible data distribution, and we propose a new Total Variation Distance++ (TVD++) loss that incorporates variance reduction techniques inspired by the policy gradient method in reinforcement learning. Our empirical results show that Llama 2 Chat Drafter 115M with speculative decoding achieves a block efficiency of up to 2.3 and a speed-up of up to 2.4$\times$ relative to autoregressive decoding on various tasks, with no further task-specific fine-tuning.
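One way to see why a total-variation-based distillation loss suits a draft model is the standard speculative sampling result: when a token is drafted from the draft distribution q and verified against the target distribution p, it is accepted with probability sum_x min(p(x), q(x)) = 1 - TVD(p, q). The sketch below illustrates that identity and the expected block efficiency (tokens produced per target forward pass) for a draft length gamma. It is a minimal illustration under stated assumptions, not the TVD++ loss proposed here: the helper names are hypothetical, the distributions are toy values, and the closed-form block efficiency assumes an i.i.d. per-token acceptance rate, which real decoding only approximates.

```python
from typing import Sequence


def total_variation(p: Sequence[float], q: Sequence[float]) -> float:
    """Plain TVD between two next-token distributions over the same vocabulary.

    (Hypothetical helper; this is the ordinary TVD, not the paper's TVD++.)
    """
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))


def acceptance_prob(p: Sequence[float], q: Sequence[float]) -> float:
    """Probability that speculative sampling accepts a token drafted from q
    when the target distribution is p: sum_x min(p(x), q(x)) = 1 - TVD(p, q)."""
    return sum(min(pi, qi) for pi, qi in zip(p, q))


def expected_block_efficiency(alpha: float, gamma: int) -> float:
    """Expected tokens generated per target forward pass when gamma tokens are
    drafted, ASSUMING each drafted token is accepted i.i.d. with rate alpha."""
    if alpha == 1.0:
        return float(gamma + 1)  # every draft accepted, plus one target token
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


if __name__ == "__main__":
    # Toy target (p) and draft (q) distributions over a 3-token vocabulary.
    p = [0.5, 0.3, 0.2]
    q = [0.4, 0.4, 0.2]
    tvd = total_variation(p, q)
    alpha = acceptance_prob(p, q)
    print(round(tvd, 3), round(alpha, 3))  # 0.1 0.9
    print(round(expected_block_efficiency(alpha, 3), 3))  # 3.439
```

Lowering the TVD between draft and target raises the acceptance rate, and under the i.i.d. assumption the block efficiency grows accordingly; the actual speed-up is smaller than the block efficiency because each drafted token also costs a draft-model forward pass.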