LLM 课程文档
在 TRL 中实现 GRPO
并获得增强的文档体验
开始使用
在 TRL 中实现 GRPO
在本页中,我们将学习如何使用 Transformer Reinforcement Learning (TRL) 库实现组相对策略优化 (GRPO)。我们将专注于实际实现,使用最少的代码。
我们将使用 TRL 官方文档中的代码片段作为指导,探索 GRPO 在 TRL 的 GRPOTrainer 中体现的核心概念。
本章面向 TRL 初学者。如果您已经熟悉 TRL,您可能还想查看 GRPO 的 Open R1 实现。
首先,让我们回顾一下 GRPO 算法的一些重要概念
- 组形成:模型为每个提示生成多个补全。
- 偏好学习:模型从奖励函数中学习,该奖励函数比较补全组。
- 训练配置:模型使用配置来控制训练过程。
我们需要做什么来实施 GRPO?
- 定义提示数据集。
- 定义一个奖励函数,该函数接受补全列表并返回奖励列表。
- 使用 GRPOConfig 配置训练过程。
- 使用 GRPOTrainer 训练模型。
这是一个开始 GRPO 训练的最小示例
from trl import GRPOTrainer, GRPOConfig
from datasets import load_dataset
# 1. Load your dataset
dataset = load_dataset("your_dataset", split="train")
# 2. Define a simple reward function
def reward_func(completions, **kwargs):
"""Example: Reward longer completions"""
return [float(len(completion)) for completion in completions]
# 3. Configure training
training_args = GRPOConfig(
output_dir="output",
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=2,
logging_steps=10,
)
# 4. Initialize and train
trainer = GRPOTrainer(
model="your_model", # e.g. "Qwen/Qwen2-0.5B-Instruct"
args=training_args,
train_dataset=dataset,
reward_funcs=reward_func,
)
trainer.train()
关键组件
1. 数据集格式
您的数据集应包含模型将响应的提示。GRPO 训练器将为每个提示生成多个补全,并使用奖励函数对它们进行比较。
2. 奖励函数
奖励函数至关重要——它决定了模型如何学习。这里有两个实用示例
# Example 1: Reward based on completion length
def reward_length(completions, **kwargs):
return [float(len(completion)) for completion in completions]
# Example 2: Reward based on matching a pattern
import re
def reward_format(completions, **kwargs):
pattern = r"^<think>.*?</think><answer>.*?</answer>$"
return [1.0 if re.match(pattern, c) else 0.0 for c in completions]
3. 训练配置
GRPOConfig
中要考虑的关键参数
training_args = GRPOConfig(
# Essential parameters
output_dir="output",
num_train_epochs=3,
num_generation=4, # Number of completions to generate for each prompt
per_device_train_batch_size=4, # We want to get all generations in one device batch
# Optional but useful
gradient_accumulation_steps=2,
learning_rate=1e-5,
logging_steps=10,
# GRPO specific (optional)
use_vllm=True, # Speed up generation
)
num_generation
参数对 GRPO 尤其重要,因为它定义了组大小——模型将为每个提示生成多少个不同的补全。这是与其他 RL 方法的关键区别。
- 过小(例如 2-3):可能无法提供足够的 Diversity 进行有意义的比较
- 推荐(4-16):在多样性和计算效率之间取得了良好的平衡
- 更大的值:可能会改善学习,但会显著增加计算成本
应根据您的计算资源和任务的复杂性选择组大小。对于简单任务,较小的组 (4-8) 可能足够,而更复杂的推理任务可能受益于较大的组 (8-16)。
成功秘诀
- 内存管理:根据您的 GPU 内存调整
per_device_train_batch_size
和gradient_accumulation_steps
。 - 速度:如果您的模型支持,请启用
use_vllm=True
以加快生成速度。 - 监控:在训练期间观察记录的指标
reward
:补全的平均奖励reward_std
:奖励组内的标准差kl
:与参考模型的 KL 散度
奖励函数设计
DeepSeek R1 论文展示了几种有效的奖励函数设计方法,您可以将其调整用于您自己的 GRPO 实现
1. 基于长度的奖励
最容易实现的奖励之一是基于长度的奖励。您可以奖励更长的补全
def reward_len(completions, **kwargs):
ideal_length = 20
return [-abs(ideal_length - len(completion)) for completion in completions]
此奖励函数会惩罚过短或过长的补全,鼓励模型生成接近理想长度 20 个 token 的补全。
2. 可验证任务的基于规则的奖励
对于具有客观正确答案的任务(例如数学或编码),您可以实现基于规则的奖励函数
def problem_reward(completions, answers, **kwargs):
"""Reward function for math problems with verifiable answers
completions: list of completions to evaluate
answers: list of answers to the problems from the dataset
"""
rewards = []
for completion, correct_answer in zip(completions, answers):
# Extract the answer from the completion
try:
# This is a simplified example - you'd need proper parsing
answer = extract_final_answer(completion)
# Binary reward: 1 for correct, 0 for incorrect
reward = 1.0 if answer == correct_answer else 0.0
rewards.append(reward)
except:
# If we can't parse an answer, give a low reward
rewards.append(0.0)
return rewards
3. 基于格式的奖励
您还可以奖励正确的格式,这在 DeepSeek R1 训练中很重要
def format_reward(completions, **kwargs):
"""Reward completions that follow the desired format"""
# Example: Check if the completion follows a think-then-answer format
pattern = r"<think>(.*?)</think>\s*<answer>(.*?)</answer>"
rewards = []
for completion in completions:
match = re.search(pattern, completion, re.DOTALL)
if match:
# Check if there's substantial content in both sections
think_content = match.group(1).strip()
answer_content = match.group(2).strip()
if len(think_content) > 20 and len(answer_content) > 0:
rewards.append(1.0)
else:
rewards.append(
0.5
) # Partial reward for correct format but limited content
else:
rewards.append(0.0) # No reward for incorrect format
return rewards
这些示例展示了如何实现受 DeepSeek R1 训练过程启发,专注于正确性、格式和组合信号的奖励函数。
就是这样!
在下一节中,您将进行一项练习,在 TRL 中实现 GRPO。
< > 在 GitHub 上更新