Large Language Models (LLMs) are distinguished by their massive parameter counts, which typically result in significant redundancy. This work introduces MaskLLM, a learnable pruning method that establishes Semi-structured (or ``N:M'') Sparsity in LLMs, aimed at reducing computational overhead during inference. Instead of developing a new importance criterion, MaskLLM explicitly models N:M patterns as a learnable distribution through Gumbel Softmax sampling. This approach facilitates end-to-end training on large-scale datasets and offers two notable advantages: 1) High-quality Masks - our method effectively scales to large datasets and learns accurate masks; 2) Transferability - the probabilistic modeling of mask distribution enables the transfer learning of sparsity across domains or tasks. We assessed MaskLLM using 2:4 sparsity on various LLMs, including LLaMA-2, Nemotron-4, and GPT-3, with sizes ranging from 843M to 15B parameters, and our empirical results show substantial improvements over state-of-the-art methods. For instance, leading approaches achieve a perplexity (PPL) of 10 or greater on Wikitext compared to the dense model's 5.12 PPL, but MaskLLM achieves a significantly lower 6.72 PPL solely by learning the masks with frozen weights. Furthermore, MaskLLM's learnable nature allows customized masks for lossless application of 2:4 sparsity to downstream tasks or domains. Code is available at \url{https://github.com/NVlabs/MaskLLM}.
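To make the core idea concrete, here is a minimal sketch of the differentiable mask modeling the abstract describes: each block of 4 weights has one of the C(4,2) = 6 valid 2:4 patterns, and per-block logits over those patterns are relaxed with Gumbel Softmax so the mask choice can be trained end-to-end. This is an illustrative NumPy reconstruction, not the authors' implementation; the function names, the temperature value, and the plain (non-straight-through) relaxation are assumptions for the sketch.

```python
import numpy as np

# All 6 valid 2:4 patterns: choose 2 of 4 positions in a block to keep.
PATTERNS = np.array([
    [1, 1, 0, 0], [1, 0, 1, 0], [1, 0, 0, 1],
    [0, 1, 1, 0], [0, 1, 0, 1], [0, 0, 1, 1],
], dtype=np.float32)

def gumbel_softmax(logits, tau=1.0, rng=None):
    """Draw a differentiable (soft) sample from a categorical distribution."""
    rng = rng or np.random.default_rng(0)
    # Gumbel(0, 1) noise: -log(-log(U)), U ~ Uniform(0, 1).
    g = -np.log(-np.log(rng.uniform(1e-9, 1.0, logits.shape)))
    y = (logits + g) / tau
    y = y - y.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(y)
    return e / e.sum(axis=-1, keepdims=True)

def soft_mask(logits, tau=1.0, rng=None):
    """Map per-block pattern logits of shape (n_blocks, 6) to soft 2:4
    masks of shape (n_blocks, 4) as a convex combination of valid patterns."""
    probs = gumbel_softmax(logits, tau, rng)  # (n_blocks, 6)
    return probs @ PATTERNS                   # every row stays in [0, 1]

rng = np.random.default_rng(42)
logits = rng.normal(size=(3, 6)).astype(np.float32)  # hypothetical learnable logits
m = soft_mask(logits, tau=0.5, rng=rng)
# Any convex combination of 2:4 patterns keeps a total "mass" of exactly 2
# per block, so the 2:4 budget is respected throughout training.
print(np.allclose(m.sum(axis=1), 2.0))  # → True
```

Because the soft mask is a convex combination of valid patterns, every gradient step stays inside the 2:4 constraint set; annealing `tau` toward zero makes the sample approach a hard one-hot pattern choice, which is what would be extracted for sparse inference.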