Mini-R1:复现Deepseek R1的“顿悟时刻”——RL教程

社区文章 发布于2025年1月31日

此文章由Philipp Schmid撰写,最初发布于philschmid.de,代码可在此处找到。

Deepseek R1的发布震惊了业界。为什么?DeepSeek-R1是一个开放模型,它在复杂推理任务中与OpenAI的o1相媲美,它通过组相对策略优化(GRPO)和以RL为中心的多阶段训练方法引入。他们不仅发布了模型,还发布了关于其实现方式的研究论文。

论文中,他们描述了在使用纯强化学习训练模型时的“顿悟时刻”。在此阶段,DeepSeek-R1-Zero(DeepSeek-R1的首次测试)学会了通过重新评估其初始方法来为问题分配更多思考时间,而无需任何人类反馈或描述如何操作的数据。他们将此描述为“顿悟时刻”,因为

这种行为不仅证明了模型日益增长的推理能力,也是强化学习如何带来意想不到的复杂结果的引人入胜的例子。

在这篇博客文章中,我们希望使用组相对策略优化(GRPO)和倒计时游戏来重现DeepSeek-R1的小“顿悟时刻”。我们将使用强化学习训练一个开放模型,尝试教它自我验证和搜索能力,以解决倒计时游戏。倒计时游戏是一个数字谜题,玩家使用一组随机抽取的数字和基本算术运算(+、-、×、÷)来达到或尽可能接近目标数字。

Target Number: 952
Available Numbers: 25, 50, 75, 100, 3, 6

(100 × (3 × 3)) + (50 + 6 / 3) = 952

这篇博客文章重点介绍了使用Deepspeed和vLLM进行分布式训练。它在4个NVIDIA H100 GPU上运行。

  1. 设置开发环境
  2. 使用Deepspeed和vLLM的GRPO分布式训练示例
  3. 结果和训练观察

注意:此博客的灵感来自Jiayi Pan,他最初探索了这个想法并用一个小模型验证了它。

但在我们开始之前,让我们先了解一下组相对策略优化(GRPO)并了解其工作原理。

组相对策略优化(GRPO)

组相对策略优化(GRPO)是一种用于提高LLM推理能力的强化学习算法。它是在DeepSeekMath论文中在数学推理背景下引入的。GRPO通过消除对值函数模型的需求来修改传统的近端策略优化(PPO)。相反,它从组分数中估计基线,减少内存使用和计算开销。GRPO,现在也由Qwen团队使用,可以与基于规则/二元的奖励以及通用奖励模型一起使用,以提高模型的有用性。

  1. 采样:使用当前策略为每个提示生成多个输出
  2. 奖励评分:使用奖励函数(可以是基于规则或基于结果的)对每个生成进行评分
  3. 优势计算:生成输出的平均奖励用作基线。然后,组内每个解决方案的优势相对于此基线进行计算。奖励在组内进行归一化。
  4. 策略优化:策略尝试最大化GRPO目标,其中包括计算出的优势和KL散度项。这与PPO在奖励中实现KL项的方式不同。

image/png

1. 设置开发环境

我们的第一步是安装Hugging Face库和Pytorch、vllm、trl、transformers和datasets。如果您还没有听说过trl,请不要担心。它是一个在transformers和datasets之上构建的新库,可以更轻松地微调、rlhf和对齐开放LLM。

# Install Pytorch & other libraries, make sure to match your GPU driver version
%pip install "torch==2.5.1" tensorboard "setuptools<71.0.0"  --index-url https://download.pytorch.org/whl/cu121

# Install flash-attn
%pip install flash-attn 

# Install Hugging Face libraries
%pip install  --upgrade \
  "transformers==4.48.1" \
  "datasets==3.1.0" \
  "accelerate==1.3.0" \
  "hf-transfer==0.1.9" \
  "deepspeed==0.15.4" \
  "trl==0.14.0"

# install vLLM 
%pip install "vllm==0.7.0"

注意:您可能需要重启内核才能使用更新的包。

我们将使用Hugging Face Hub作为远程模型版本控制服务。这意味着我们将在训练期间自动将模型、日志和信息推送到Hub。为此,您必须在Hugging Face上注册。拥有账户后,我们将使用huggingface_hub包中的login工具登录我们的账户并将我们的令牌(访问密钥)存储在磁盘上。

from huggingface_hub import login

login(token="", add_to_git_credential=True) # ADD YOUR TOKEN HERE

2. 使用Deepspeed和vLLM的GRPO分布式训练示例

我们将使用Jiayi-Pan/Countdown-Tasks-3to4数据集,其中包含3到4个数字和解决方案的样本。作为模型,我们将使用Qwen/Qwen2.5-3B-Instruct,这是一个3B参数的指令微调模型。这使得展示“顿悟时刻”更容易,因为它已经可以遵循指令。Jiayi-Pan研究发现,模型需要达到一定的质量才能学习推理过程,从>1.5B参数开始。

TRL通过专用的GRPOTrainer支持组相对策略优化(GRPO),用于根据偏好数据对LLM进行对齐,如DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models中所述。GRPOTrainertransformers库中Trainer的子类,并支持所有相同的功能,包括日志记录、检查点、分布式训练和参数高效微调(PEFT)。

GRPOTrainer支持通用结果奖励模型(ORM)和自定义奖励函数,可用于实现基于规则的奖励模型。在Deepseek R1论文中,他们实现了基于规则的奖励模型来验证生成解决方案的正确性。在我们的示例中,我们将采用类似的方法,创建2个奖励函数,它们将

  1. 格式奖励:检查生成的格式是否正确<think> [思考] </think><answer> [答案] </answer>
  2. 准确性奖励:从<answer>标签中提取方程式,并根据目标和每个数字是否只使用一次来评估它。

注意:我们示例中正确的<answer>包括方程式,例如<answer> 55 + 36 - 7 - 19 </answer>

Hugging Face TRL增加了对使用Deepspeed进行分布式训练和使用vLLM进行更快生成。我准备了一个run_r1_grpo.py脚本和一个receipes/grpo-qwen-2.5-3b-deepseek-r1-countdown.yaml配置文件来运行训练。

此配置已在具有4个H100 80GB GPU的节点上进行测试和验证,其中单步大约需要45-60秒,因为我们可以利用vLLM进行生成,并利用DeepSpeed进行分布式训练。因此,我们需要确保将num_processes正确设置为您拥有的GPU数量减1,因为最后一个GPU将用于vLLM进行生成。如果您使用更多GPU,则需要将配置文件中的vllm_device更改为最后一个索引GPU,例如,如果您有8个GPU,则需要设置vllm_device=7并将num_processes设置为7。

运行训练的命令

accelerate launch --num_processes 3 --config_file configs/accelerate_configs/deepspeed_zero3.yaml scripts/run_r1_grpo.py --config receipes/grpo-qwen-2.5-3b-deepseek-r1-countdown.yaml

通过优化的分布式训练,在4个H100 80GB GPU上,每个样本生成8个样本的单步大约需要45-60秒。450步的完整训练大约需要6小时。

3. 结果和训练观察

脚本将随机完成保存到completion_samples文件夹中,您可以使用该文件夹检查模型的进度。它包括completion_samples.txtsuccess_completion_samples.txtcompletion_samples.txt包含所有完成,而success_completion_samples.txt则正确解决了方程式。下面您可以找到关于性能如何随时间变化的有趣训练观察,以及Tensorboard日志和成功的推理样本。

每25步的检查点模型可在philschmid/qwen-2.5-3b-r1-countdown找到。

超参数

我开始实验时使用了DeepSeekMath论文中的超参数,学习率为1e-6,beta(KL系数)为0.04,这导致在大约150步后训练运行不稳定。我进行了一些小型消融实验,并将学习率降低到5e-7,beta降低到0.001,这基于OpenRLHF的测试。我无法测试将num_generations从8增加到64会对训练产生什么影响。64是DeepSeekMath论文中使用的生成值。所有其他参数可在grpo-qwen-2.5-3b-deepseek-r1-countdown.yaml配置文件中找到。

训练观察:

  • 大约50步时,模型学会了正确的格式<think>...</think>\n<answer>...</answer>
  • 在100步时,解决方程式的成功率约为25%。模型开始用文字“推理”,请参阅下面的示例。
  • 在200步时,性能似乎收敛得更慢,我们成功率约为40%。模型开始学习一种新的“格式”,它以类似于您编写程序的方式解决方程式,通过尝试不同的组合并检查结果,请参阅“200到450步之间成功推理样本”。
  • 在450步时,我们解决方程式的成功率为50%。性能仍在缓慢提高,模型保持了从200步开始的新格式。

我有3个潜在的假设,为什么模型会从“文字推理”转向“程序化执行”

  1. Qwen 2.5 3B不够强大或太小,Deepseek提到您需要一个非常强大的基础模型。
  2. 奖励函数定义不够好,它奖励黑客模型来解决方程式。我们可以尝试强制它使用文字,例如,具有数字到文字频率条件。(我们对Deepseek的奖励函数了解不多)
  3. 仅在倒计时游戏任务上进行训练可能会自然地迫使模型学习解决方程式的最有效方法,因为不需要其他格式。
  4. 我们训练模型的时间不够长,在R1论文中,他们展示了训练超过8000步的视觉效果。

Tensorboard 结果

image/png

截至200步的成功推理样本

<think> First, I need to think about the numbers we have: 14, 70, and 80. To get 85, I think 80 might be a good starting point because it's right in between 85. If I subtract or add 80, it won't lead to 85 easily. Next, 70 seems to be a good intermediate number, especially to use in conjunction with 14. 
Now, I can think about different combinations using addition, subtraction and multiplication/division. If I take 80 - 70 = 10, which is not needed. So, I need to find a way to use 14. If I subtract 14 from 80, I get 66, which still isn't 85. If I subtract 70 from 80, I get 10 again, this isn't constructive yet. 
Multiplication could create numbers larger than the target 85. Let's consider division: If I divide 70 by 14, I get 5. Then if I add this result to 80, I get 85. This combination seems promising: 80 - (70 / 14) = 80 - 5 = 75, this still isn't quite right. Let me correct my approach: Let's use division and then addition. If I take 80 / (70 / 14) → 80 / 5 = 16, and then 80 - 16 = 64, still incorrect. Let me simplify: 80 / (70 / 14) → 80 / 5 = 16, and then 80 - 16 gives 64 again, let me correct: 80 - (70 / 14) → 80 - 5 = 75, but I need 85. Thus I will use the combinations in a slightly different order to insert numbers: I get 80 + (70 / 14) = 80 + 5 = 85</think>
<answer> 80 + (70 / 14) </answer>
<think> First, I need to consider how I can use the numbers 1, 21, and 90 to reach 68. I can try different combinations of operations. Let's try to use subtraction and addition. If I subtract 90 from 21, I get -69, which is too far from 68. Now, if I add 1 to 90, I get 91, which is also not right. Another way is to see what happens if I subtract 21 from 90, which gives me 69. This is closer but still not 68. Let me use multiplication instead of subtraction. 90 * 1 = 90. Now, if I try to reduce this slightly by subtracting 21, I am not getting close to 68.
Now, let's try another multiplication. I'll think of 1 * (larger value) + (smaller value). Let's try 1 * 90 + 21. This evaluates to 90 + 21 = 111, which is too much. Instead, I will try: 90 / 1 + 21 - 7. But this gives a division by 7, which seems complex. Let's re-evaluate: I'll use 1 to divide 21: 21 / 1 = 21. Now, subtract 21 from 90, and add 1, we have 90 - 21 = 69, using one...
Hey, wait on re-evaluating I will subtract 21 from 90, and include 1 to complete the solution: 90 - (21 + 1) which would be 68.</think>
<answer> 90 - (21 + 1) </answer>
<think> First, I need to consider the numbers we have: 59, 80, and 66. I need to use these numbers to get 45. Let's try to use subtraction first since it can quickly reduce large numbers.
Let's see if we can get close by subtracting or adding these numbers. If I subtract 66 from 80, I get 14. Then I need to make this 14 equal 45 when combined with 59. That gives me 14 + 59 = 73, which is too high.
Now let's try adding numbers. If I add 59 and 66, I get 125. Then I would need to subtract 80 from 125 to get 45, but 80 is too large to fit well with this.
Let me try a different approach. What if I subtract 66 from 80 first and then add 59? That would be 14 + 59 = 73, which is still not 45.
Hmm, I need to think of another combination. What if I start with 59 and then add 80 and subtract 66? That would be (59 + 80) - 66, let me calculate it. That equals 139 - 66, which is 73.
That's still not 45, so let me try rearranging it slightly: maybe 80 - 66 + 59 would be closer to 45. Calculating again: 80 - 66 is 14, then adding 59 gives 73.
This way doesn't work either. Let's flip this: maybe (59 + 66) - 80. Let me check: that's 125 - 80, which equals 45.</think>
<answer>(59 + 66) - 80</answer>
<think> First, I need to find a combination of the numbers 61, 4, 94, and 21 that can be manipulated using basic arithmetic operations to equal 50. One way to approach this is to look for numbers that can be used to reach a target number close to 50 and then adjust with the remaining numbers.
Let's try combining 94 and 21 first:
94 + 21 = 115
Now, we need to get from 115 to 50. We can try subtracting 65:
115 - 65 = 50
Now we need to find a way to get 65 using the remaining numbers 61 and 4. We can try:
61 - 4 = 57
Now we need to get from 57 to 65. Adding 8 would give us that:
57 + 8 = 65
So, we need to find a way to get 8 using 61 and 4. We can try:
61 - 94 = -36
This is negative, so let's try dividing:
61 / 94 = 0.648987...
This is not helping, so let's try multiplication and division in a different way:
(61 - 4) * (94 / 21) = 57 * 4.476... = 254.4
This is not working either. Let's try a different approach:
(94 - 61) = 33
Now we have:
33 + 21 = 54
We need to adjust from 54 to 50, so we can:
54 - 4 = 50
So, the equation becomes:
(94 - 61) + (21 - 4) = 50
</think>
<answer>(94 - 61) + (21 - 4)</answer>
<think> First, I'll consider the possibility of directly subtracting two numbers to get close to 31. If I subtract 57 from 63, I get 7. Now, if I subtract this 7 from 25, I get 25 - 7 = 18. Far from 31, so this doesn't work.
Next, let's try adding and subtracting. Dividing 63 by 25 will give a decimal, so that's unlikely to work. Now, let's see if adding and subtracting in another order could work. For example, let's try 63 - 25, and then we see this gets us 38. Now, we need a way to get to 31 from here. Subtracting 7 from 38 would work, but we can't figure out how to create that 7 from the given numbers.
Hmm, let me try to incorporate all numbers by adding and subtracting. If I try 63 + 25, I get 88. This is quite far from 31, but if 88 - 57, I get 31 exactly!
This seems to be the equation: 63 + 25 - 57 equals 31.
</think>
<answer> 63 + 25 - 57 </answer>

200步到450步之间成功推理样本

<think> We need to use the numbers 39, 16, and 55 with basic arithmetic operations to get 78. Let's try different combinations:
- 55 + 39 - 16 = 84 - 16 = 68 (too low)
- 55 + 16 - 39 = 71 - 39 = 32 (too low)
- 55 - 16 + 39 = 39 + 39 = 78 (this works)
So, the equation is 55 - 16 + 39 = 78. </think>
<answer> 55 - 16 + 39 </answer>
<think> We need to use the numbers 21, 5, and 40 with basic arithmetic operations to get 24. Let's try different combinations:
- 21 - 5 - 40 = -15
- 21 - 5 + 40 = 56
- 21 + 5 - 40 = -14
- 21 * 5 / 40 = 2.1 (not an integer, let's try another)
- 40 * 5 / 21 = 9.52 (not 24)
- 40 / 5 - 21 = 9 - 21 = -12
- 5 * 40 / 21 = 100 / 21 ~= 4.76 (not 24)
- 5 + 40 - 21 = 24 ! This works.
</think>
<answer> 5 + 40 - 21 </answer>
<think> We need to use the numbers 75, 99, 62, and 34 with basic arithmetic operations to get 72. Let's try different combinations:
- 99 - 75 + 62 - 34 = 92 - 34 = 58 (too high)
- 99 - 75 - 62 + 34 = 24 - 62 + 34 = -38 + 34 = -4 (too low)
- 99 + 75 - 62 - 34 = 174 - 96 = 78 (too high)
- 99 + 75 - 62 + 34 = 174 + 34 - 62 = 208 - 62 = 146 (too high)
- 99 - 75 + 62 + 34 = 24 + 96 = 120 (too high)
- 75 + 99 - 62 - 34 = 174 - 96 = 78 (too high)
- 75 + 99 - 62 + 34 = 174 + 34 - 62 = 208 - 62 = 146 (too high)
- 75 + 62 - 99 + 34 = 137 - 99 + 34 = 38 + 34 = 72
So, 75 + 62 - 99 + 34 equals 72.
</think>
<answer> 75 + 62 - 99 + 34 </answer>

结论

DeepSeek R1及其研究论文的发布可能是开放科学和开源开发的转折点。DeepSeek发布仅一周后,我们就能够使用GRPO和倒计时游戏重现R1学习“推理”的简单版本。虽然我们的实现侧重于特定任务而不是通用推理和收敛到非常特定的“推理”格式,但它表明该方法是有效的。

展望2025年,我们显然正处于取得更大进展的边缘。强化学习将变得更加易于访问和用户友好,更多的研究人员和开发人员将探索其潜力,但也将需要比以前和与监督微调相比更多的计算量。

我对2025年充满期待。如果您有任何问题或想法,请随时与我联系

如果这听起来很有趣,我们很乐意得到您的帮助!无论是贡献代码,还是加入Hugging Face上的讨论。

社区

注册登录发表评论