Large language models have driven significant progress in natural language processing, but their deployment requires substantial compute and memory resources. As models scale, compression techniques become essential for balancing model quality with computational efficiency. Structured pruning, which removes less critical components of the model, is a promising strategy for reducing complexity. However, one-shot pruning often results in significant quality degradation, particularly in tasks requiring multi-step reasoning. To recover lost quality, supervised fine-tuning (SFT) is commonly applied, but it can lead to catastrophic forgetting by shifting the model's learned data distribution. Therefore, addressing the degradation from both pruning and SFT is essential to preserve the original model's quality. In this work, we propose self-data distilled fine-tuning to address these challenges. Our approach leverages the original, unpruned model to generate a distilled dataset that preserves semantic richness and mitigates catastrophic forgetting by maintaining alignment with the base model's knowledge. Empirically, we demonstrate that self-data distillation consistently outperforms standard SFT, improving average accuracy by up to 8% on the HuggingFace OpenLLM Leaderboard v1. Specifically, when pruning 6 decoder blocks from Llama3.1-8B Instruct (i.e., 32 to 26 layers, reducing the model size from 8.03B to 6.72B parameters), our method retains 91.2% of the original model's accuracy compared to 81.7% with SFT, while reducing real-world FLOPs by 16.30%. Furthermore, our approach scales effectively across datasets, with quality improving as dataset size increases.
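The two ingredients described above can be sketched in a few lines. This is a minimal illustration, not the paper's exact recipe: `teacher_generate` stands in for the original, unpruned model, and the non-empty-output check is an assumed placeholder for whatever quality gate a real pipeline would apply.

```python
# Hypothetical sketch of self-data distilled fine-tuning data preparation.
# `teacher_generate` is a stand-in for the original, unpruned model; the
# fallback check is an illustrative assumption, not the paper's method.

def self_distill_dataset(examples, teacher_generate):
    """Rewrite each target with the unpruned teacher's own response, so the
    fine-tuning data stays aligned with the base model's distribution and
    mitigates catastrophic forgetting."""
    distilled = []
    for prompt, original_target in examples:
        teacher_response = teacher_generate(prompt, original_target)
        # Keep the teacher's output when it is non-empty; otherwise fall
        # back to the original label (a simple sanity gate for illustration).
        target = teacher_response if teacher_response.strip() else original_target
        distilled.append((prompt, target))
    return distilled


def drop_decoder_blocks(layers, start, count):
    """Toy structured pruning: remove a contiguous span of decoder blocks
    (e.g., 6 of 32 layers, as in the Llama3.1-8B Instruct example)."""
    return layers[:start] + layers[start + count:]
```

For example, `drop_decoder_blocks(list(range(32)), 20, 6)` leaves 26 layers, matching the 32-to-26 configuration in the abstract; the pruned model would then be fine-tuned on the output of `self_distill_dataset` rather than the raw SFT labels.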