LLM 课程文档

在 TRL 中实现 GRPO

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

在 TRL 中实现 GRPO

在本页中,我们将学习如何使用 Transformer Reinforcement Learning (TRL) 库实现组相对策略优化 (GRPO)。我们将专注于实际实现,使用最少的代码。

我们将使用 TRL 官方文档中的代码片段作为指导,探索 GRPO 在 TRL 的 GRPOTrainer 中体现的核心概念。

本章面向 TRL 初学者。如果您已经熟悉 TRL,您可能还想查看 GRPO 的 Open R1 实现

首先,让我们回顾一下 GRPO 算法的一些重要概念

  • 组形成:模型为每个提示生成多个补全。
  • 偏好学习:模型从奖励函数中学习,该奖励函数比较补全组。
  • 训练配置:模型使用配置来控制训练过程。

我们需要做什么来实施 GRPO?

  • 定义提示数据集。
  • 定义一个奖励函数,该函数接受补全列表并返回奖励列表。
  • 使用 GRPOConfig 配置训练过程。
  • 使用 GRPOTrainer 训练模型。

这是一个开始 GRPO 训练的最小示例

from trl import GRPOTrainer, GRPOConfig
from datasets import load_dataset

# 1. Load your dataset
dataset = load_dataset("your_dataset", split="train")


# 2. Define a simple reward function
def reward_func(completions, **kwargs):
    """Example: Reward longer completions"""
    return [float(len(completion)) for completion in completions]


# 3. Configure training
training_args = GRPOConfig(
    output_dir="output",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    logging_steps=10,
)

# 4. Initialize and train
trainer = GRPOTrainer(
    model="your_model",  # e.g. "Qwen/Qwen2-0.5B-Instruct"
    args=training_args,
    train_dataset=dataset,
    reward_funcs=reward_func,
)
trainer.train()

关键组件

1. 数据集格式

您的数据集应包含模型将响应的提示。GRPO 训练器将为每个提示生成多个补全,并使用奖励函数对它们进行比较。

2. 奖励函数

奖励函数至关重要——它决定了模型如何学习。这里有两个实用示例

# Example 1: Reward based on completion length
def reward_length(completions, **kwargs):
    return [float(len(completion)) for completion in completions]


# Example 2: Reward based on matching a pattern
import re


def reward_format(completions, **kwargs):
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    return [1.0 if re.match(pattern, c) else 0.0 for c in completions]

3. 训练配置

GRPOConfig 中要考虑的关键参数

training_args = GRPOConfig(
    # Essential parameters
    output_dir="output",
    num_train_epochs=3,
    num_generation=4,  # Number of completions to generate for each prompt
    per_device_train_batch_size=4,  # We want to get all generations in one device batch
    # Optional but useful
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    logging_steps=10,
    # GRPO specific (optional)
    use_vllm=True,  # Speed up generation
)

num_generation 参数对 GRPO 尤其重要,因为它定义了组大小——模型将为每个提示生成多少个不同的补全。这是与其他 RL 方法的关键区别。

  • 过小(例如 2-3):可能无法提供足够的 Diversity 进行有意义的比较
  • 推荐(4-16):在多样性和计算效率之间取得了良好的平衡
  • 更大的值:可能会改善学习,但会显著增加计算成本

应根据您的计算资源和任务的复杂性选择组大小。对于简单任务,较小的组 (4-8) 可能足够,而更复杂的推理任务可能受益于较大的组 (8-16)。

成功秘诀

  1. 内存管理:根据您的 GPU 内存调整 per_device_train_batch_sizegradient_accumulation_steps
  2. 速度:如果您的模型支持,请启用 use_vllm=True 以加快生成速度。
  3. 监控:在训练期间观察记录的指标
    • reward:补全的平均奖励
    • reward_std:奖励组内的标准差
    • kl:与参考模型的 KL 散度

奖励函数设计

DeepSeek R1 论文展示了几种有效的奖励函数设计方法,您可以将其调整用于您自己的 GRPO 实现

1. 基于长度的奖励

最容易实现的奖励之一是基于长度的奖励。您可以奖励更长的补全

def reward_len(completions, **kwargs):
    ideal_length = 20
    return [-abs(ideal_length - len(completion)) for completion in completions]

此奖励函数会惩罚过短或过长的补全,鼓励模型生成接近理想长度 20 个 token 的补全。

2. 可验证任务的基于规则的奖励

对于具有客观正确答案的任务(例如数学或编码),您可以实现基于规则的奖励函数

def problem_reward(completions, answers, **kwargs):
    """Reward function for math problems with verifiable answers
    completions: list of completions to evaluate
    answers: list of answers to the problems from the dataset
    """

    rewards = []
    for completion, correct_answer in zip(completions, answers):
        # Extract the answer from the completion
        try:
            # This is a simplified example - you'd need proper parsing
            answer = extract_final_answer(completion)
            # Binary reward: 1 for correct, 0 for incorrect
            reward = 1.0 if answer == correct_answer else 0.0
            rewards.append(reward)
        except:
            # If we can't parse an answer, give a low reward
            rewards.append(0.0)

    return rewards

3. 基于格式的奖励

您还可以奖励正确的格式,这在 DeepSeek R1 训练中很重要

def format_reward(completions, **kwargs):
    """Reward completions that follow the desired format"""
    # Example: Check if the completion follows a think-then-answer format
    pattern = r"<think>(.*?)</think>\s*<answer>(.*?)</answer>"

    rewards = []
    for completion in completions:
        match = re.search(pattern, completion, re.DOTALL)
        if match:
            # Check if there's substantial content in both sections
            think_content = match.group(1).strip()
            answer_content = match.group(2).strip()

            if len(think_content) > 20 and len(answer_content) > 0:
                rewards.append(1.0)
            else:
                rewards.append(
                    0.5
                )  # Partial reward for correct format but limited content
        else:
            rewards.append(0.0)  # No reward for incorrect format

    return rewards

这些示例展示了如何实现受 DeepSeek R1 训练过程启发,专注于正确性、格式和组合信号的奖励函数。

就是这样!

在下一节中,您将进行一项练习,在 TRL 中实现 GRPO。

< > 在 GitHub 上更新