Alignment is critical when training large language models (LLMs). The predominant approach is Reinforcement Learning from Human Feedback (RLHF), with Proximal Policy Optimization (PPO) as the de facto algorithm. However, PPO is known to be computationally inefficient, and this paper aims to address that problem. We identify three important properties of RLHF tasks that PPO does not exploit: fast simulation, deterministic transitions, and trajectory-level rewards. Based on these observations, we develop ReMax, a new algorithm tailored for RLHF. ReMax builds on the classic REINFORCE algorithm and adds a new variance-reduction technique. Our method has three advantages over PPO. First, by removing PPO's value model, it saves about 50% of memory usage in principle; for example, PPO runs out of memory when fine-tuning a Llama2 (7B) model on 8xA100-40GB GPUs, whereas ReMax can train in that setting. Second, ReMax is simple to implement and removes many of PPO's hyper-parameters, which are scale-sensitive and laborious to tune. Third, on GPT2 (137M), we observe a 2.2x speed-up in wall-clock time. Importantly, these computational improvements do not sacrifice performance. We hypothesize that these advantages carry over to larger models. Our implementation of ReMax is available at https://github.com/liziniu/ReMax
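To make "REINFORCE with variance reduction" concrete, here is a minimal sketch on a toy one-step problem, written in plain Python. The abstract does not specify the variance-reduction technique, so the choice below is an assumption: the reward of the greedy (most likely) response is subtracted as a baseline. The policy (a softmax over a handful of candidate responses), the function names, and the reward values are all hypothetical and are not taken from the ReMax repository.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def remax_style_gradient(logits, rewards, sampled):
    """One-sample REINFORCE gradient estimate w.r.t. the logits.

    Assumption (not stated in the abstract): the baseline is the reward
    of the greedy response, i.e. the response with the highest probability.
    """
    probs = softmax(logits)
    greedy = max(range(len(probs)), key=probs.__getitem__)
    # Trajectory-level reward of the sample minus the greedy baseline.
    advantage = rewards[sampled] - rewards[greedy]
    # For a softmax policy, grad of log pi(sampled) is onehot(sampled) - probs.
    return [
        advantage * ((1.0 if j == sampled else 0.0) - p)
        for j, p in enumerate(probs)
    ]


if __name__ == "__main__":
    rng = random.Random(0)
    logits = [0.5, -1.0, 2.0]   # hypothetical policy parameters
    rewards = [1.0, 0.0, 0.3]   # hypothetical trajectory-level rewards
    sampled = rng.choices(range(len(logits)), weights=softmax(logits))[0]
    print(remax_style_gradient(logits, rewards, sampled))
```

The baseline does not depend on the sampled response, so the estimate remains unbiased in expectation, and no separate value model is needed to compute it. The cost is one extra greedy generation per prompt, which is cheap when simulation is fast.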