Scaling Large Language Models (LLMs) for retrieval-based tasks, particularly Retrieval Augmented Generation (RAG), runs into significant memory constraints, especially when fine-tuning on long prompt sequences. Current open-source libraries support full-model inference and fine-tuning across multiple GPUs, but they fall short of the efficient parameter distribution needed to accommodate retrieved context. To address this gap, we introduce a novel framework for fine-tuning Llama-2 models with Parameter-Efficient Fine-Tuning (PEFT) methods using distributed training. The framework leverages JAX's just-in-time (JIT) compilation and tensor sharding for efficient resource management, enabling faster fine-tuning with lower memory requirements. This advancement significantly improves the scalability and feasibility of fine-tuning LLMs for complex RAG applications, even on systems with limited GPU resources. In our experiments with four GPUs, the framework achieves a runtime improvement of more than 12x over a Hugging Face/DeepSpeed implementation while using less than half the VRAM per GPU.
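To make the combination of JIT compilation and tensor sharding more concrete, here is a minimal sketch of a PEFT-style training step in JAX. It is not the authors' framework. It assumes LoRA as the PEFT method and uses a single toy linear layer in place of a Llama-2 block. The sizes, the learning rate, and the helper names (`init_params`, `forward`, `loss_fn`, `train_step`) are hypothetical. The frozen base weight is split column-wise across all available devices with `NamedSharding`, and only the small low-rank factors receive gradients. `jax.jit` compiles the whole step, which lets XLA insert any cross-device communication the sharded matmul needs.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical toy sizes, chosen for illustration only.
D_IN, RANK = 16, 4
N_DEV = len(jax.devices())
D_OUT = 8 * N_DEV  # divisible by the device count so columns split evenly
LR = 0.1           # hypothetical learning rate for plain SGD

# 1-D device mesh; the "model" axis shards weight columns across devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))


def init_params(key):
    k_w, k_a = jax.random.split(key)
    # Frozen base weight W, sharded column-wise (tensor sharding).
    w = jax.device_put(
        jax.random.normal(k_w, (D_IN, D_OUT)) / jnp.sqrt(D_IN),
        NamedSharding(mesh, P(None, "model")),
    )
    # Trainable LoRA factors (assumed PEFT method), replicated on every device.
    # B starts at zero, so the adapted layer initially equals the frozen layer.
    lora = jax.device_put(
        {
            "a": jax.random.normal(k_a, (D_IN, RANK)) / jnp.sqrt(D_IN),
            "b": jnp.zeros((RANK, D_OUT)),
        },
        NamedSharding(mesh, P()),
    )
    return w, lora


def forward(w, lora, x):
    # y = x W + (x A) B
    return x @ w + (x @ lora["a"]) @ lora["b"]


def loss_fn(lora, w, x, y):
    # Mean squared error; stands in for the language-modeling loss.
    return jnp.mean((forward(w, lora, x) - y) ** 2)


@jax.jit
def train_step(lora, w, x, y):
    # Gradients only w.r.t. the LoRA factors (argnums=0); W stays frozen.
    loss, grads = jax.value_and_grad(loss_fn)(lora, w, x, y)
    lora = jax.tree_util.tree_map(lambda p, g: p - LR * g, lora, grads)
    return lora, loss


# Usage on random data.
w, lora = init_params(jax.random.PRNGKey(0))
kx, ky = jax.random.split(jax.random.PRNGKey(1))
x = jax.random.normal(kx, (32, D_IN))
y = jax.random.normal(ky, (32, D_OUT))

losses = []
for _ in range(50):
    lora, loss = train_step(lora, w, x, y)
    losses.append(float(loss))

print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
print(w.sharding)
```

Gradients and updates touch only `a` and `b`, so their memory cost scales with `RANK` rather than with the full weight. This is the property PEFT methods rely on. On a single device the mesh has one entry and the same code still runs unchanged.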