Alignment is critical for training large language models (LLMs). The predominant approach is Reinforcement Learning from Human Feedback (RLHF), for which Proximal Policy Optimization (PPO) is the de facto algorithm. PPO, however, is known to be computationally inefficient, and this paper aims to address that problem. We identify three important properties of RLHF tasks that PPO does not exploit: fast simulation, deterministic transitions, and trajectory-level rewards. Based on these observations, we develop ReMax, a new algorithm tailored for RLHF. ReMax builds on the classic REINFORCE algorithm and adds a new variance-reduction technique. Our method has three advantages over PPO. First, ReMax is simple to implement and removes many of PPO's hyper-parameters, which are scale-sensitive and laborious to tune. Second, ReMax saves about 50% of memory usage in principle, because it removes the value model used in PPO. As a result, PPO runs out of memory when fine-tuning a Llama2 (7B) model on 8xA100-40GB GPUs, whereas ReMax can afford the training. Third, based on our calculations, even if PPO could afford to train Llama2 (7B), it would still run about 2x slower than ReMax, owing to the computational overhead of the value model, which ReMax does not have. Importantly, these computational improvements do not sacrifice performance. We hypothesize that these advantages carry over to larger-scale models. Our implementation of ReMax is available at https://github.com/liziniu/ReMax
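To make the idea of REINFORCE with a variance-reducing baseline concrete, here is a minimal toy sketch, not the authors' implementation. It replaces the LLM with a softmax policy over three candidate "responses", each with one trajectory-level reward. The baseline used here is the reward of the greedy (highest-probability) response. That choice is an assumption about the variance-reduction technique, which the abstract does not spell out. The function names (`expected_gradient`, `sampled_gradient`, `train`) and the reward values are hypothetical.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def expected_gradient(logits, rewards, baseline):
    """Exact expectation of the REINFORCE estimator for a fixed baseline.

    Enumerates every response, so it shows that the baseline does not
    change the expected gradient.
    """
    probs = softmax(logits)
    n = len(logits)
    grad = [0.0] * n
    for a in range(n):
        adv = rewards[a] - baseline
        for i in range(n):
            # d log pi(a) / d logit_i for a softmax policy
            score = (1.0 if i == a else 0.0) - probs[i]
            grad[i] += probs[a] * adv * score
    return grad


def sampled_gradient(logits, rewards, rng):
    """One-sample estimate using the greedy response's reward as baseline.

    Assumption: the greedy baseline is an illustrative choice.
    """
    probs = softmax(logits)
    n = len(logits)
    sampled = rng.choices(range(n), weights=probs, k=1)[0]
    greedy = max(range(n), key=lambda i: logits[i])
    adv = rewards[sampled] - rewards[greedy]
    return [adv * ((1.0 if i == sampled else 0.0) - probs[i]) for i in range(n)]


def train(rewards, steps=2000, lr=0.5, seed=0):
    """Plain gradient ascent on the logits; no value model is involved."""
    rng = random.Random(seed)
    logits = [0.0] * len(rewards)
    for _ in range(steps):
        grad = sampled_gradient(logits, rewards, rng)
        logits = [x + lr * g for x, g in zip(logits, grad)]
    return softmax(logits)


if __name__ == "__main__":
    rewards = [0.1, 1.0, 0.3]  # hypothetical trajectory-level rewards
    print(train(rewards))
```

In this sketch the baseline costs one extra (greedy) rollout rather than a learned value model, which is where the fast-simulation property would matter in practice. Because the baseline does not depend on the sampled response, the expected gradient is the same as without it.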