Large language models (LLMs) have achieved remarkable success across diverse tasks, yet their inference remains costly in time and energy because only a single token is generated at each decoding step. Prior methods such as speculative decoding mitigate this inefficiency by producing multiple tokens per step, but each token is still sampled from its single-token distribution, so speed improves while output quality does not. In contrast, our work improves inference speed and output quality simultaneously. We consider multi-token joint decoding (MTJD), which generates multiple tokens from their joint distribution at each iteration, theoretically reducing perplexity and improving task performance. However, MTJD suffers from the high cost of sampling from the joint distribution of multiple tokens. Inspired by speculative decoding, we introduce multi-token assisted decoding (MTAD), a novel framework designed to accelerate MTJD. MTAD leverages a smaller auxiliary model to approximate the joint distribution of a larger model, incorporating a verification mechanism that not only ensures the accuracy of this approximation but also improves decoding efficiency over conventional speculative decoding. Theoretically, we show that MTAD approximates exact MTJD with bounded error. Empirical evaluations with Llama-2 and OPT models ranging from 13B to 70B parameters across various tasks show that MTAD reduces perplexity by 21.2% and improves downstream performance relative to standard single-token sampling. Furthermore, MTAD achieves a 1.42x speed-up and consumes 1.54x less energy than conventional speculative decoding. These results highlight MTAD's ability to make multi-token joint decoding both effective and efficient, promoting more sustainable, high-performance deployment of LLMs.