Alignment is crucial for training large language models. The predominant strategy is Reinforcement Learning from Human Feedback (RLHF), with Proximal Policy Optimization (PPO) as the de facto algorithm. Yet PPO is known to suffer from computational inefficiency, a challenge this paper aims to address. We identify three important properties of RLHF tasks that PPO does not leverage: fast simulation, deterministic transitions, and trajectory-level rewards. Based on these properties, we develop ReMax, a new algorithm tailored for RLHF. ReMax builds on the celebrated REINFORCE algorithm but is enhanced with a new variance-reduction technique. ReMax offers three advantages over PPO. First, it is simple to implement, requiring just 6 lines of code, and it eliminates more than 4 of PPO's hyper-parameters, which are laborious to tune. Second, ReMax reduces memory usage by about 50%. For example, PPO runs out of memory when fine-tuning a Llama2-7B model on A100-80GB GPUs, whereas ReMax can support the training. Even when memory-saving techniques (e.g., ZeRO and offloading) are employed so that PPO can be trained, ReMax can use a larger batch size to increase throughput. Third, in terms of wall-clock time, each PPO iteration takes about twice as long as a ReMax iteration. Importantly, these improvements do not sacrifice task performance. We hypothesize that these advantages carry over to larger-scale models.
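The abstract does not spell out the variance-reduction technique, so the following is only a minimal sketch under a stated assumption: we assume the REINFORCE baseline is the reward of the greedily decoded response, giving the estimator (r(y) - r(ȳ)) ∇ log π(y). The toy "language model" (one softmax per position), the vocabulary size `V`, the length `T`, the `TARGET` sequence, and the match-counting `reward` are all hypothetical stand-ins, not a Llama2 policy or a learned reward model. The sketch shows how the three RLHF properties are used. Rewards are assigned only to the complete response. Transitions are deterministic, since each token is simply appended to the prefix. Fast simulation makes the extra greedy rollout per step cheap. `exact_expected_gradient` enumerates every response to check that the greedy baseline leaves the gradient unbiased.

```python
import math
import random
from itertools import product

# Hypothetical toy policy: T positions, each with its own logit vector over V tokens,
# so pi(y) = prod_t softmax(logits[t])[y_t]. Not a real LLM.
V, T = 4, 3
TARGET = [1, 3, 2]  # hypothetical "preferred" response scored by the toy reward


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in row]
    s = sum(exps)
    return [e / s for e in exps]


def reward(seq):
    # Trajectory-level reward: only the finished response is scored.
    return sum(1.0 for y, g in zip(seq, TARGET) if y == g)


def sample(logits, rng):
    # Deterministic transitions: the next state is the prefix plus the sampled token.
    return [rng.choices(range(V), weights=softmax(row))[0] for row in logits]


def greedy(logits):
    # Greedy decoding (ties go to the lowest token index).
    return [max(range(V), key=lambda v: row[v]) for row in logits]


def grad_log_prob(logits, seq):
    # d log pi(seq) / d logits[t][v] = 1[v == seq[t]] - softmax(logits[t])[v]
    grads = []
    for row, y in zip(logits, seq):
        p = softmax(row)
        grads.append([(1.0 if v == y else 0.0) - p[v] for v in range(V)])
    return grads


def remax_gradient(logits, seq):
    # Assumed ReMax-style estimator: REINFORCE with the greedy response's reward
    # as a baseline. The baseline depends only on the logits, not on seq.
    advantage = reward(seq) - reward(greedy(logits))
    return [[advantage * g for g in row] for row in grad_log_prob(logits, seq)]


def train(steps=3000, lr=0.1, seed=0):
    rng = random.Random(seed)
    logits = [[0.0] * V for _ in range(T)]
    for _ in range(steps):
        seq = sample(logits, rng)  # one sampled response per step
        g = remax_gradient(logits, seq)
        for t in range(T):
            for v in range(V):
                logits[t][v] += lr * g[t][v]  # gradient ascent; no value network
    return logits


def exact_expected_gradient(logits):
    # Enumerate all V**T responses and average the estimator under pi.
    probs = [softmax(row) for row in logits]
    total = [[0.0] * V for _ in range(T)]
    for seq in product(range(V), repeat=T):
        w = math.prod(probs[t][y] for t, y in enumerate(seq))
        g = remax_gradient(logits, list(seq))
        for t in range(T):
            for v in range(V):
                total[t][v] += w * g[t][v]
    return total


if __name__ == "__main__":
    learned = train()
    print(greedy(learned), reward(greedy(learned)))
```

Running the script trains the toy policy, and its greedy response should move toward `TARGET`. Because the toy reward decomposes per position, the exact expected gradient for position t works out to p_target · (1[v = target] − p_v), whatever the baseline value is. That is the sense in which subtracting r(ȳ) reduces variance without adding bias. Unlike PPO, the sketch keeps no learned value function.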