LLM 课程文档
DeepSeekMath 中群组相对策略优化 (GRPO) 的高级理解
并获得增强的文档体验
开始使用
DeepSeekMath 中群组相对策略优化 (GRPO) 的高级理解
本节深入探讨 GRPO 的技术和数学细节。它由 Shirin Yamani 撰写。
让我们加深对 GRPO 的理解,以便改进模型的训练过程。
GRPO 通过比较同一代生成组内的响应来直接评估模型生成的响应,从而优化策略模型,而不是训练单独的价值模型(Critic)。这种方法显著降低了计算成本!
GRPO 可以应用于任何可以确定响应正确性的可验证任务。例如,在数学推理中,可以通过将其与真实答案进行比较来轻松验证响应的正确性。
在深入探讨技术细节之前,让我们从宏观层面可视化 GRPO 的工作原理
现在我们对 GRPO 有了可视化的概述,让我们逐步分解 GRPO 的工作原理。
GRPO 算法
GRPO 的核心创新在于其同时评估和学习多个生成响应的方法。它不依赖于单独的奖励模型,而是比较同一组内的输出,以确定哪些应该被强化。
让我们详细了解算法的每个步骤
步骤 1:群组抽样
第一步是为每个问题生成多个可能的答案。这创建了一个可以相互比较的多元化输出集。
对于每个问题,模型将生成个输出来自训练过的策略(组大小):{ },其中每个代表模型的一次补全。
示例
为了使之具体化,让我们看一个简单的算术问题
问题 :
输出 :
请注意,一些生成的答案是正确的 (14),而另一些是错误的(16 或 10)。这种多样性对于下一步至关重要。
步骤 2:优势计算
一旦我们有了多个响应,我们需要一种方法来确定哪些响应优于其他响应。这就是优势计算的用武之地。
奖励分配
首先,我们为每个生成的响应分配一个奖励分数。在本例中,我们将使用奖励模型,但正如我们在上一节中所学的,我们可以使用任何返回奖励的函数。
根据正确性为每个生成的响应分配 RM 分数 (例如,正确响应为 1,错误响应为 0) 然后对于每个计算以下优势值。
优势值公式
GRPO 的关键洞察力在于我们不需要质量的绝对度量 - 我们可以比较同一组内的输出。这是通过标准化完成的
示例
继续我们的算术示例,对于上面的同一个例子,假设我们有 8 个响应,其中 4 个是正确的,其余是错误的,因此:
指标 | 值 |
---|---|
组平均值 | |
标准差 | |
优势值(正确响应) | |
优势值(错误响应) |
解释
现在我们已经计算了优势值,让我们理解它们的含义
这种标准化(即加权)允许模型评估每个响应的相对性能,引导优化过程朝着优于平均水平的有利响应(高奖励)发展,并阻止那些较差的响应。例如,如果,那么是比其组内平均水平更好的响应;并且如果,那么那么响应的质量低于平均水平(即质量/性能差)。
对于上面的例子,如果那么在优化步骤中,其生成概率将会增加。
有了我们计算的优势值,我们现在准备好更新策略了。
步骤 3:策略更新
最后一步是使用这些优势值来更新我们的模型,使其在未来更有可能生成好的响应。
策略更新的目标函数是
这个公式乍一看可能令人生畏,但它由几个组件构建而成,每个组件都服务于重要的目的。 让我们逐一分解它们。
目标函数的关键组成部分
GRPO 更新函数结合了几种技术,以确保稳定有效的学习。 让我们检查每个组件
1. 概率比率
概率比率定义为
直观地说,该公式比较了新模型的响应概率与旧模型的响应概率的差异程度,同时结合了对改善预期结果的响应的偏好。
解释
- 如果,新模型为响应分配了更高的概率比旧模型。
- 如果,新模型为
分配了较低的概率。这个比率允许我们控制模型在每个步骤中改变多少,这引导我们到下一个组件。
2. 裁剪函数
裁剪函数定义为
限制上述比率在以下范围内:以避免/控制剧烈的变化或疯狂的更新,以及避免离旧策略太远。换句话说,它限制了概率比率可以增加的幅度,从而通过避免使新模型偏离旧模型太远的更新来帮助维持稳定性。
示例 (ε = 0.2)
让我们看两个不同的场景,以更好地理解这个裁剪函数。
- 案例 1:如果新策略对于特定响应的概率为 0.9,而旧策略的概率为 0.5,则意味着这个响应正在被新策略加强,以获得更高的概率,但在受控的限制范围内,即裁剪,以限制其幅度,避免剧烈变化 -(上限限制为 1.2)
- 案例 2:如果新策略不赞成某个响应(较低的概率,例如 0.2),这意味着如果该响应不利,则增加可能是错误的,模型将受到惩罚。 -(下限限制为 0.8)
解释
- 该公式鼓励新模型支持旧模型低估的响应,如果这些响应能改善结果。
- 如果旧模型已经高概率地支持某个响应,新模型仍然可以加强它,但只能在受控的限制内,。
- 如果旧模型高估了表现不佳的响应,则新模型会被阻止维持高概率。
- 因此,直观地说,通过结合概率比率,目标函数确保策略的更新与优势成正比,同时被缓和以防止剧烈变化。 T
虽然裁剪函数有助于防止剧烈变化,但我们需要一个额外的安全措施来确保我们的模型不会过度偏离其原始行为。
3. KL 散度
KL 散度项是
在 KL 散度项中,基本上是更新前模型的输出,per_token_logps
和是新模型的输出,new_per_token_logps
。理论上,最小化 KL 散度是为了防止模型在优化过程中过度偏离其原始行为。这有助于在基于奖励信号提高性能和保持连贯性之间取得平衡。在这种情况下,最小化 KL 散度降低了模型生成无意义文本的风险,或者在数学推理的情况下,产生极其不正确的答案的风险。
解释
- KL 散度惩罚使模型的输出接近其原始分布,防止极端偏移。
- 模型不会漂移到完全不合理的输出,而是会在允许一定探索的同时改进其理解。
数学定义
对于那些对数学细节感兴趣的人,让我们看一下正式的定义。
回顾一下,KL 距离定义如下:在 RLHF 中,感兴趣的两个分布通常是新模型版本的分布 P(x) 和参考策略 Q(x) 的分布。
β 参数的作用
系数控制我们强制执行 KL 散度约束的强度。
- 更高的 β 值(更强的 KL 惩罚)
- 对策略更新的约束更多。模型保持接近其参考分布。
- 可能会减慢适应速度:模型可能难以探索更好的响应。
- 更低的 β 值(更弱的 KL 惩罚)
- 策略更新的自由度更高:模型可以更多地偏离参考。
- 更快的适应速度,但存在不稳定性风险:模型可能会学习奖励利用行为。
- 过度优化风险:如果奖励模型存在缺陷,策略可能会生成无意义的输出。
- 原始 DeepSeekMath 论文将此值设置为
既然我们了解了 GRPO 的组成部分,让我们看看它们在一个完整的示例中是如何协同工作的。
GRPO 的工作示例
为了巩固我们对 GRPO 的理解,让我们从头到尾完成一个完整的示例。
示例问题
步骤 1:群组抽样
首先,我们从模型中生成多个响应。
生成个响应,其中 个是正确答案 ( \( 14, \text{reward=} 1 \) ),而个是不正确的,因此
步骤 2:优势计算
接下来,我们计算优势值以确定哪些响应优于平均水平。
统计 | 值 |
---|---|
组平均值 | |
标准差 | |
优势值(正确响应) | |
优势值(错误响应) |
步骤 3:策略更新
最后,我们更新模型以强化正确的响应。
- 假设旧策略 ( \( \pi{\theta{old}} \) ) 对于正确输出的概率是,新策略将其提高到,那么
- 然后,当目标函数被重新加权时,模型倾向于强化正确输出的生成,并且限制了与参考策略的偏差。
在理论理解到位后,让我们看看 GRPO 如何在代码中实现。
实现示例
让我们在一个实际的例子中将所有内容放在一起。以下代码演示了如何在 PyTorch 中实现 GRPO。
1. 加载模型并生成响应
首先,我们需要加载一个模型并为给定的问题生成多个响应。
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the model and tokenizer
model_name = "Qwen/Qwen2-Math-1.5B"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Input prompt
prompt = "Solve y = 2x + 1 for x = 2, y = " # Correct answer: 5
inputs = tokenizer(prompt, return_tensors="pt", padding=True)
input_ids = inputs["input_ids"].to(device) # Shape: (1, prompt_len)
attention_mask = inputs["attention_mask"].to(device)
# Step 1: Generate 8 responses (B = 2 groups, G = 4 responses per group)
batch_size, num_generations = 2, 4
outputs = model.generate(
input_ids=input_ids, # Shape: (1, prompt_len)
attention_mask=attention_mask,
max_new_tokens=1, # seq_len = 1 (single token per response)
num_return_sequences=batch_size * num_generations, # 8 responses total
do_sample=True,
top_k=10,
temperature=0.7,
pad_token_id=tokenizer.eos_token_id,
return_dict_in_generate=True,
output_scores=True,
)
此初始生成(在任何步骤之前)将输出类似这样的内容
Output 1: 5.0 Output 2: 6.0 Output 3: 7.0 Output 4: 5.0 Output 5: 10.0 Output 6: 2.0 Output 7: 5.0 Output 8: 5.0
2. 计算奖励
现在,我们需要确定哪些响应是正确的,并相应地分配奖励。
使用 GRPO,对于相同的示例提示,我们生成多个补全结果。例如,对于我们的提示 "Solve y = 2x + 1 for x = 2, y = "
和 Solve y = 2x + 1 for x = 4, y = "
,我们有两个针对给定提示生成的输出组,一组例如是
[5, 6, 7, 5]
,另一组是[10, 2, 9, 9]
,而正确答案是 5 和 9。
请注意,在实践中,这些奖励分数是通过基于规则的奖励函数实现的,该函数根据响应的正确性分配奖励,或者通过更复杂的基于神经网络的模型来实现,该模型可以训练为根据响应的正确性或两者的混合来分配奖励。但为了简单起见,假设如果响应正确,我们的每个响应的奖励为 1,如果响应错误,则奖励为 0;因此:
reward_1 = [1, 0, 0, 1]
reward_2 = [0, 0, 1, 1]
接下来,我们获得奖励的组均值和标准差;
# Shape: (B * G,) = (8,) bc we have 2 groups of 4 generations that we flatten
rewards = torch.tensor([1, 0, 0, 1, 0, 0, 1, 1], dtype=torch.float32)
num_generations = 4
# Group rewards: Shape (B, G) = 2, 4)
rewards_grouped = rewards.view(-1, num_generations)
# Mean per group: Shape (B,) = (2,)
mean_grouped_rewards = rewards_grouped.mean(dim=1)
# Std per group: Shape (B,) = (2,)
std_grouped_rewards = rewards_grouped.std(dim=1)
# Broadcast to match rewards and normalize: Shape (B * G,) = (8,)
# why we need to broadcast? because we need to calculate the advantage values for each response within the group
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(num_generations, dim=0)
这将输出
Grouped Rewards: tensor([[1., 0., 0., 1.], [0., 0., 1., 1.]]) Mean per group: tensor([0.5000, 0.5000]) Std per group: tensor([0.5774, 0.5774]) Broadcasted Mean: tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000]) Broadcasted Std: tensor([0.5774, 0.5774, 0.5774, 0.5774, 0.5774, 0.5774, 0.5774, 0.5774])
现在我们可以计算每个响应的优势值了
# Advantages: Shape (B * G,) = (8,)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-8)
这将输出
Advantages: tensor([ 0.8659, -0.8660, -0.8660, 0.8659, -0.8660, -0.8660, 0.8659, 0.8659])
这来自上面的优势公式,所以
For reward_1 = [1, 0, 0, 1]: 1 - 0.5 / 0.5774 ≈ 0.8659 0 - 0.5 / 0.5774 ≈ -0.8660 For reward_2 = [0, 0, 1, 1]: Same pattern.
但是,这里的形状是 (B*G,) = (8,)
,但在实践中,我们需要形状为 (B, G) = (2, 4)
以匹配 logits 形状,对吗?因此,我们需要对优势张量进行unsqueeze操作,使其形状为 (B*G, 1) = (8, 1)
以匹配 logits 形状。
# Shape (B * G, 1) = (8, 1) to match the logits shape
advantages = advantages.unsqueeze(1)
这将输出
Advantages: tensor([[ 0.8659], [-0.8660], [-0.8660], [ 0.8659], [-0.8660], [-0.8660], [ 0.8659], [ 0.8659]])
现在我们一切就绪,让我们继续下一步,基于优势值更新策略模型。
3. 更新策略
最后,我们使用优势值来更新我们的模型。
# Compute probability ratio between new and old policies
ratio = torch.exp(
new_per_token_logps - per_token_logps
) # Shape: (B*G, seq_len) seq_len is the length of the output i.e. the num of generated tokens so here for simplicity let's assume it is 1 # (8, 1)
请注意,per_token_logps
可以通过将生成的输出传递给模型并获取 logits,然后应用 softmax 函数以获得概率 F.softmax(logits, dim=-1)
来实现。
# Clipping Function
eps = self.cliprange # e.g. 0.2
pg_losses1 = -advantages * ratio # Shape: (B*G, seq_len) #(8, 1)
pg_losses2 = -advantages * torch.clamp(
ratio, 1.0 - eps, 1.0 + eps
) # Shape: (B*G, seq_len) #(8, 1)
pg_loss_max = torch.max(pg_losses1, pg_losses2) # Shape: (B*G, seq_len) #(8, 1)
# Now Combine with KL penalty # Shape: (B*G, seq_len) #(8, 1)
per_token_loss = pg_loss_max + self.beta * per_token_kl
per_token_kl
也可以按如下方式计算
# Shape: (B*G, seq_len) #(8, 1)
per_token_kl = F.kl_div(
F.log_softmax(new_per_token_logps, dim=-1),
F.softmax(per_token_logps, dim=-1),
reduction="none",
).sum(dim=-1, keepdim=True)
完整的示例可以在这里找到。GRPO 也由优秀的 TRL 团队实现,您可以查看 TRL/GRPO_trainer 的实现以获取更多详细信息。
总结和后续步骤
恭喜!您现在已经了解了组相对策略优化 (GRPO)。为了回顾我们涵盖的内容:
- GRPO 比较组内的多个输出,以确定哪些输出优于其他输出,而无需单独的价值模型。
- 优势计算标准化奖励,以识别哪些响应高于或低于平均水平。
- 策略更新使用带有 KL 散度惩罚的裁剪目标函数,以确保稳定的学习。
这种方法对于数学推理任务尤其强大,在这些任务中,正确性可以客观地验证。与需要单独评论家模型的传统 RLHF 方法相比,GRPO 方法允许更有效的训练。
当您继续探索 GRPO 时,请考虑尝试不同的组大小、奖励函数和 KL 惩罚系数,以了解它们如何影响模型的性能。
祝您训练愉快!🚀