LLM 课程文档

在 TRL 中实施 GRPO

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始

在 TRL 中实施 GRPO

在本页中,我们将学习如何使用 Transformer 强化学习 (TRL) 库实施 Group Relative Policy Optimization (GRPO)。我们将专注于使用最少代码的实际实施。

我们将探索 GRPO 的核心概念,因为它们体现在 TRL 的 GRPOTrainer 中,并使用来自官方 TRL 文档的片段来指导我们。

本章面向 TRL 初学者。如果您已经熟悉 TRL,您可能还想查看 GRPO 的 Open R1 实现

首先,让我们回顾一下 GRPO 算法的一些重要概念

  • 群组形成:模型为每个提示生成多个完成。
  • 偏好学习:模型从比较完成群组的奖励函数中学习。
  • 训练配置:模型使用配置来控制训练过程。

我们需要做什么来实施 GRPO?

  • 定义提示数据集。
  • 定义一个奖励函数,该函数接受完成列表并返回奖励列表。
  • 使用 GRPOConfig 配置训练过程。
  • 使用 GRPOTrainer 训练模型。

这是一个开始 GRPO 训练的最小示例

from trl import GRPOTrainer, GRPOConfig
from datasets import load_dataset

# 1. Load your dataset
dataset = load_dataset("your_dataset", split="train")


# 2. Define a simple reward function
def reward_func(completions, **kwargs):
    """Example: Reward longer completions"""
    return [float(len(completion)) for completion in completions]


# 3. Configure training
training_args = GRPOConfig(
    output_dir="output",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    logging_steps=10,
)

# 4. Initialize and train
trainer = GRPOTrainer(
    model="your_model",  # e.g. "Qwen/Qwen2-0.5B-Instruct"
    args=training_args,
    train_dataset=dataset,
    reward_funcs=reward_func,
)
trainer.train()

关键组件

1. 数据集格式

您的数据集应包含模型将响应的提示。GRPO 训练器将为每个提示生成多个完成,并使用奖励函数来比较它们。

2. 奖励函数

奖励函数至关重要 - 它决定了模型如何学习。以下是两个实际示例

# Example 1: Reward based on completion length
def reward_length(completions, **kwargs):
    return [float(len(completion)) for completion in completions]


# Example 2: Reward based on matching a pattern
import re


def reward_format(completions, **kwargs):
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    return [1.0 if re.match(pattern, c) else 0.0 for c in completions]

3. 训练配置

GRPOConfig 中需要考虑的关键参数

training_args = GRPOConfig(
    # Essential parameters
    output_dir="output",
    num_train_epochs=3,
    num_generation=4,  # Number of completions to generate for each prompt
    per_device_train_batch_size=4,  # We want to get all generations in one device batch
    # Optional but useful
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    logging_steps=10,
    # GRPO specific (optional)
    use_vllm=True,  # Speed up generation
)

num_generation 参数对于 GRPO 尤为重要,因为它定义了群组大小 - 模型将为每个提示生成多少个不同的完成。这是与其他 RL 方法的关键区别

  • 太小(例如,2-3):可能无法为有意义的比较提供足够的 diversity
  • 推荐(4-16):在 diversity 和计算效率之间提供良好的平衡
  • 更大的值:可能会改善学习,但会显着增加计算成本

群组大小应根据您的计算资源和任务的复杂性来选择。对于简单的任务,较小的群组(4-8)可能就足够了,而更复杂的推理任务可能会受益于较大的群组(8-16)。

成功秘诀

  1. 内存管理:根据您的 GPU 内存调整 per_device_train_batch_sizegradient_accumulation_steps
  2. 速度:如果您的模型受支持,请启用 use_vllm=True 以获得更快的生成速度。
  3. 监控:在训练期间观察记录的指标
    • reward:完成的平均奖励
    • reward_std:奖励组内的标准差
    • kl:来自参考模型的 KL 散度

奖励函数设计

DeepSeek R1 论文展示了几种有效的奖励函数设计方法,您可以将其应用于您自己的 GRPO 实施。

1. 基于长度的奖励

最容易实现的奖励之一是基于长度的奖励。您可以奖励更长的完成

def reward_len(completions, **kwargs):
    ideal_length = 20
    return [-abs(ideal_length - len(completion)) for completion in completions]

此奖励函数惩罚过短或过长的完成,鼓励模型生成接近理想长度 20 个 token 的完成。

2. 针对可验证任务的基于规则的奖励

对于具有客观正确答案的任务(如数学或编码),您可以实施基于规则的奖励函数

def problem_reward(completions, answers, **kwargs):
    """Reward function for math problems with verifiable answers
    completions: list of completions to evaluate
    answers: list of answers to the problems from the dataset
    """

    rewards = []
    for completion, correct_answer in zip(completions, answers):
        # Extract the answer from the completion
        try:
            # This is a simplified example - you'd need proper parsing
            answer = extract_final_answer(completion)
            # Binary reward: 1 for correct, 0 for incorrect
            reward = 1.0 if answer == correct_answer else 0.0
            rewards.append(reward)
        except:
            # If we can't parse an answer, give a low reward
            rewards.append(0.0)

    return rewards

3. 基于格式的奖励

您还可以奖励正确的格式,这在 DeepSeek R1 训练中很重要

def format_reward(completions, **kwargs):
    """Reward completions that follow the desired format"""
    # Example: Check if the completion follows a think-then-answer format
    pattern = r"<think>(.*?)</think>\s*<answer>(.*?)</answer>"

    rewards = []
    for completion in completions:
        match = re.search(pattern, completion, re.DOTALL)
        if match:
            # Check if there's substantial content in both sections
            think_content = match.group(1).strip()
            answer_content = match.group(2).strip()

            if len(think_content) > 20 and len(answer_content) > 0:
                rewards.append(1.0)
            else:
                rewards.append(
                    0.5
                )  # Partial reward for correct format but limited content
        else:
            rewards.append(0.0)  # No reward for incorrect format

    return rewards

这些示例演示了如何实施受 DeepSeek R1 训练过程启发的奖励函数,重点关注正确性、格式和组合信号。

就是这样!

在下一节中,您将跟随一个练习在 TRL 中实施 GRPO。

< > 在 GitHub 上更新