Scaling Large Language Models (LLMs) for retrieval-based tasks, particularly Retrieval Augmented Generation (RAG), runs into significant memory constraints, especially when fine-tuning on long prompt sequences. Current open-source libraries support full-model inference and fine-tuning across multiple GPUs, but they fall short of the efficient parameter distribution needed when prompts carry retrieved context. To address this gap, we introduce a framework for parameter-efficient fine-tuning (PEFT)-compatible fine-tuning of Llama-2 models with distributed training. The framework uses JAX's just-in-time (JIT) compilation and tensor sharding to manage resources efficiently, enabling faster fine-tuning with lower memory requirements. This improves the scalability and feasibility of fine-tuning LLMs for complex RAG applications, even on systems with limited GPU resources. In our experiments on four GPUs, the framework runs more than 12x faster than a Hugging Face/DeepSpeed implementation while using less than half the VRAM per GPU. Our library will be open-sourced in due course.
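The combination of PEFT, JIT compilation and tensor sharding described above can be illustrated with a minimal JAX sketch. It is not the authors' implementation. It assumes LoRA as the PEFT method and uses a single toy linear layer with made-up sizes (`D_MODEL`, `RANK`, `BATCH`, `SEQ`, `LR` are hypothetical) in place of Llama-2. The frozen weight is split column-wise across a 1-D device mesh. Only the small low-rank adapters `a` and `b` receive gradients, and the whole update step is compiled once with `jax.jit`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical toy sizes (not Llama-2 dimensions).
# D_MODEL must be divisible by the number of devices in the mesh.
D_MODEL, RANK, BATCH, SEQ = 64, 4, 8, 16
LR = 0.1

def lora_linear(params, frozen_w, x):
    # Frozen projection plus a trainable low-rank (LoRA-style) update: x W + (x A) B.
    return x @ frozen_w + (x @ params["a"]) @ params["b"]

def loss_fn(params, frozen_w, x, y):
    return jnp.mean((lora_linear(params, frozen_w, x) - y) ** 2)

@jax.jit  # traced and compiled by XLA once per input shape/sharding
def train_step(params, frozen_w, x, y):
    # Differentiate only w.r.t. the adapter params (argument 0); the frozen
    # weight gets no gradient and no optimizer state.
    loss, grads = jax.value_and_grad(loss_fn)(params, frozen_w, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return params, loss

# 1-D mesh over all visible devices (e.g. 4 GPUs, or a single CPU when testing).
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
col_sharded = NamedSharding(mesh, P(None, "model"))  # split W's output columns
replicated = NamedSharding(mesh, P())                # full copy on every device

k_w, k_a, k_u, k_v, k_x = jax.random.split(jax.random.PRNGKey(0), 5)
w = jax.random.normal(k_w, (D_MODEL, D_MODEL)) / jnp.sqrt(D_MODEL)
frozen_w = jax.device_put(w, col_sharded)
params = jax.device_put(
    {"a": 0.1 * jax.random.normal(k_a, (D_MODEL, RANK)),
     "b": jnp.zeros((RANK, D_MODEL))},  # B = 0: training starts from the base layer
    replicated,
)

# Synthetic target: the base layer plus a hidden low-rank change.
delta = 0.1 * jax.random.normal(k_u, (D_MODEL, RANK)) @ jax.random.normal(k_v, (RANK, D_MODEL))
x = jax.device_put(jax.random.normal(k_x, (BATCH, SEQ, D_MODEL)), replicated)
y = jax.device_put(x @ (w + delta), replicated)

losses = []
for _ in range(20):
    params, loss = train_step(params, frozen_w, x, y)  # loss before this update
    losses.append(float(loss))
print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

The same script runs unchanged on one CPU or several GPUs. Only the mesh size changes, and XLA inserts whatever cross-device communication the column-sharded matmul needs. Keeping the adapters replicated is a simplifying choice here. Because they are tiny compared with the frozen weight, most of the per-device memory saving comes from sharding the frozen weight.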