TRL 文档

GRPO 训练器

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

GRPO 训练器

概述

TRL 支持 GRPO 训练器用于训练语言模型,如论文 DeepSeekMath: 推动开放语言模型数学推理的极限 所述,作者为 Zhihong Shao, Peiyi Wang, Qihao Zhu, Runxin Xu, Junxiao Song, Mingchuan Zhang, Y. K. Li, Y. Wu, Daya Guo

论文摘要如下:

数学推理因其复杂和结构化的性质,对语言模型构成了重大挑战。在本文中,我们介绍了 DeepSeekMath 7B,它在 DeepSeek-Coder-Base-v1.5 7B 的基础上继续预训练,使用了来自 Common Crawl 的 1200 亿个数学相关标记,以及自然语言和代码数据。DeepSeekMath 7B 在竞赛级别的 MATH 基准测试中取得了令人印象深刻的 51.7% 的分数,而无需依赖外部工具包和投票技术,其性能水平接近 Gemini-Ultra 和 GPT-4。DeepSeekMath 7B 在 64 个样本上进行自洽性测试,在 MATH 上达到了 60.9%。DeepSeekMath 的数学推理能力归因于两个关键因素:首先,我们通过精心设计的数据选择管道,充分利用了公开网络数据的巨大潜力。其次,我们引入了组相对策略优化(GRPO),这是近端策略优化(PPO)的一个变体,它在增强数学推理能力的同时,优化了 PPO 的内存使用。

此后训练方法由 Quentin Gallouédec 贡献。

快速入门

此示例演示了如何使用 GRPO 方法训练模型。我们使用 TLDR 数据集 中的提示(忽略完成列!)训练了一个 Qwen 0.5B Instruct 模型。你可以在此处查看数据集中的数据

以下是训练模型的脚本。

# train_grpo.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO")
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

使用以下命令执行脚本

accelerate launch train_grpo.py

在 8 个 GPU 上分布式训练大约需要 1 天。

深入了解 GRPO 方法

GRPO 是一种在线学习算法,这意味着它通过使用训练好的模型在训练期间生成的数据进行迭代改进。GRPO 目标的直觉是最大化生成完成的优势,同时确保模型保持接近参考策略。要理解 GRPO 的工作原理,可以将其分解为四个主要步骤:**生成完成**、**计算优势**、**估计 KL 散度**和**计算损失**。

生成完成

在每个训练步骤中,我们采样一批提示并生成一组G G 每个提示的完成(表示为oi o_i ).

计算优势

对于每个G G 序列,我们使用奖励模型计算奖励。为了与奖励模型的比较性质保持一致——通常在相同问题的输出之间进行比较的数据集上进行训练——优势的计算方式反映了这些相对比较。其标准化如下:A^i,t=rimean(r)std(r)\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}

这种方法因此得名:**群组相对策略优化 (GRPO)**。

论文 理解 R1-Zero 类似训练:批判性视角 表明,按std(r) \text{std}(\mathbf{r}) 缩放可能会导致问题级别难度偏差。你可以通过在 GRPOConfig 中设置 scale_rewards=False 来禁用此缩放。

估计 KL 散度

KL 散度使用 Schulman et al. (2020) 引入的近似器进行估计。近似器定义如下:DKL[πθπref]=πref(oi,tq,oi,<t)πθ(oi,tq,oi,<t)logπref(oi,tq,oi,<t)πθ(oi,tq,oi,<t)1,\mathbb{D}_{\text{KL}}\left[\pi_\theta \|\pi_{\text{ref}}\right] = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - \log \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1,

计算损失

目标是最大化优势,同时确保模型保持接近参考策略。因此,损失定义如下:LGRPO(θ)=1i=1Goii=1Gt=1oi[πθ(oi,tq,oi,<t)[πθ(oi,tq,oi,<t)]no gradA^i,tβDKL[πθπref]], \mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],

其中第一项代表缩放后的优势,第二项通过 KL 散度惩罚与参考策略的偏差。

注意,与 DeepSeekMath: 推动开放语言模型数学推理的极限 中原始公式相比,我们没有按1oi \frac{1}{|o_i|} 缩放,因为论文 理解 R1-Zero 类似训练:批判性视角 表明这会引入响应级别的长度偏差。更多详情请参见损失类型

注意,与 DeepSeekMath: 推动开放语言模型数学推理的极限 中的原始公式相比,我们默认使用β=0.0 \beta = 0.0 ,这意味着不使用 KL 散度项。此选择受到几项近期研究的启发(例如,Open-Reasoner-Zero: 一种在基础模型上扩展强化学习的开源方法),这些研究表明 KL 散度项对于 GRPO 训练并非必不可少。因此,将其排除已成为常见做法(例如 理解 R1-Zero 类似训练:批判性视角DAPO: 一种大规模开源 LLM 强化学习系统)。如果你希望包含 KL 散度项,可以在 GRPOConfig 中将 beta 设置为非零值。

在原始论文中,此公式被泛化以考虑每次生成后的多次更新(表示为μ \mu ,可在 GRPOConfig 中使用 num_iterations 设置),通过利用裁剪替代目标LGRPO(θ)=1i=1Goii=1Gt=1oi[min(πθ(oi,tq,oi,<t)πθold(oi,tq,oi,<t)A^i,t,clip(πθ(oi,tq,oi,<t)πθold(oi,tq,oi,<t),1ϵ,1+ϵ)A^i,t)βDKL[πθπref]], \mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],

其中clip(,1ϵ,1+ϵ)\text{clip}(\cdot, 1 - \epsilon, 1 + \epsilon) 确保更新不会过度偏离参考策略,通过限制策略比率在以下范围之间:1ϵ 1 - \epsilon 1+ϵ 1 + \epsilon 。当μ=1 \mu = 1 (TRL中的默认值)时,裁剪的替代目标简化为原始目标。

损失类型

文献中提出了几种目标函数形式。最初,GRPO的目标函数定义如下:LGRPO(θ)=1Gi=1G1oit=1oili,t, \mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} l_{i,t},

其中li,t=πθ(oi,tq,oi,<t)[πθ(oi,tq,oi,<t)]no gradA^i,tβDKL[πθπref]. l_{i,t} = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right].

DAPO论文强调了GRPO算法在长CoT场景中样本级损失的局限性,即较长的响应受到惩罚不足,导致输出质量较差。提出的解决方案是token级归一化,它通过为单个token分配更平衡的奖励,更好地处理较长序列,而不管响应长度如何:LDAPO(θ)=1i=1Goii=1Gt=1oili,t, \mathcal{L}_{\text{DAPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},

此外,在理解R1-Zero类训练:一个批判性视角论文中,作者指出原始GRPO公式引入了响应长度偏差。他们表明,虽然DAPO公式减少了这种偏差,但并未完全消除。为了完全消除这种偏差,他们提出用一个常数而不是序列长度进行除法,从而得到以下公式:LDr. GRPO(θ)=1LGi=1Gt=1oili,t, \mathcal{L}_{\text{Dr. GRPO}}(\theta) = - \frac{1}{LG} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},

这个常数建议设置为最大完成长度。要使用此公式,请在GRPOConfig中将loss_type设置为"dr_grpo"

记录指标

  • num_tokens:迄今为止处理的令牌总数,包括提示和完成。
  • completions/mean_length:生成的完成的平均长度。
  • completions/min_length:生成的完成的最小长度。
  • completions/max_length:生成的完成的最大长度。
  • completions/mean_terminated_length:以EOS终止的生成的完成的平均长度。
  • completions/min_terminated_length:以EOS终止的生成的完成的最小长度。
  • completions/max_terminated_length:以EOS终止的生成的完成的最大长度。
  • completions/clipped_ratio:截断(裁剪)完成的比例。
  • reward/{reward_func_name}/mean:特定奖励函数的平均奖励。
  • reward/{reward_func_name}/std:特定奖励函数的奖励标准差。
  • reward:应用奖励权重后的总平均奖励。
  • reward_std:应用奖励权重后,每个批次内总奖励的标准差。
  • frac_reward_zero_std:生成批次中奖励标准差为零的样本比例,这意味着该提示的多样性很小(所有答案都正确或不正确)。
  • entropy:生成的完成中token预测的平均熵。(如果`mask_truncated_completions=True`,则排除被掩码的序列token。)
  • kl:模型与参考模型之间的平均KL散度,在生成的完成上计算。仅当`beta`不为零时记录。
  • clip_ratio/region_mean:GRPO目标被裁剪以保持在信任区域内的token(或序列,如果importance_sampling_level="sequence")概率的平均比率clip(ri,t(θ),1ϵlow,1+ϵhigh),ri,t(θ)=πθ(oi,tq,oi,<t)πθold(oi,tq,oi,<t). \text{clip}\left( r_{i,t}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}\,. 值越高意味着裁剪的token越多,这限制了策略$\pi_\theta$可以改变的幅度。
  • clip_ratio/low_mean:在信任区域下限被裁剪的token(或序列,如果importance_sampling_level="sequence")概率的平均比率ri,t(θ)<1ϵlowr_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}
  • clip_ratio/low_min:在信任区域下限被裁剪的token(或序列,如果importance_sampling_level="sequence")概率的最小比率ri,t(θ)<1ϵlowr_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}
  • clip_ratio/high_mean:在信任区域上限被裁剪的token(或序列,如果importance_sampling_level="sequence")概率的平均比率ri,t(θ)>1+ϵhighr_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}
  • clip_ratio/high_max:在信任区域上限被裁剪的token(或序列,如果importance_sampling_level="sequence")概率的最大比率ri,t(θ)>1+ϵhighr_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}.

定制化

通过vLLM加速训练中的生成过程

在使用在线方法进行训练时,生成通常是主要的瓶颈。为了加速生成,您可以使用vLLM,一个用于LLM的高吞吐量、低延迟推理引擎。要启用它,首先通过以下方式安装软件包:

pip install trl[vllm]

我们支持两种在训练期间使用 vLLM 的方式:**服务器模式**和**共置模式**。

🔌 选项 1:服务器模式

在此模式下,vLLM 在单独的进程中(并使用单独的 GPU)运行,并通过 HTTP 与训练器通信。如果您有专用的 GPU 用于推理,此模式是理想选择。

  1. 启动 vLLM 服务器:

    trl vllm-serve --model <model_name>
  2. 在您的训练脚本中启用服务器模式:

    from trl import GRPOConfig
    
    training_args = GRPOConfig(
        ...,
        use_vllm=True,
        vllm_mode="server",  # default value, can be omitted
    )

请确保服务器使用的 GPU 与训练器不同,否则可能会遇到 NCCL 错误。您可以通过 `CUDA_VISIBLE_DEVICES` 环境变量指定要使用的 GPU。

🧩 选项 2:并置模式

在此模式下,vLLM 在训练器进程内运行,并与训练模型共享 GPU 内存。这避免了启动单独的服务器,可以提高 GPU 利用率,但也可能导致训练 GPU 上的内存争用。

from trl import GRPOConfig

training_args = GRPOConfig(
    ...,
    use_vllm=True,
    vllm_mode="colocate",
)

根据模型大小和训练的整体 GPU 内存要求,您可能需要调整 GRPOConfig 中的 vllm_gpu_memory_utilization 参数,以避免 GPU 利用率不足或内存不足错误。

我们提供了一个 HF Space 来帮助您根据模型配置和实验设置估算推荐的 GPU 内存利用率。只需按如下方式使用即可获得 vllm_gpu_memory_utilization 推荐

如果推荐值在您的环境中不起作用,我们建议在推荐值的基础上增加一个小的缓冲区(例如,+0.05 或 +0.1)以确保稳定性。

默认情况下,GRPO 对 vLLM 使用 MASTER_ADDR=localhostMASTER_PORT=12345,但您可以通过相应地设置环境变量来覆盖这些值。

有关更多信息,请参阅 使用 vLLM 加速训练

大规模 GRPO:在多个节点上训练 70B+ 模型

当训练像 Qwen2.5-72B 这样的大模型时,您需要一些关键的优化来使其在多个 GPU 和节点上高效且可扩展。这些优化包括

  • DeepSpeed ZeRO Stage 3:ZeRO 利用数据并行来将模型状态(权重、梯度、优化器状态)分布到多个 GPU 和 CPU 上,从而减少每个设备的内存和计算要求。由于大模型无法在单个 GPU 上运行,因此训练此类模型需要使用 ZeRO Stage 3。有关更多详细信息,请参阅 DeepSpeed 集成
  • Accelerate:Accelerate 是一个简化跨多个 GPU 和节点分布式训练的库。它提供了一个简单的 API 来启动分布式训练,并处理分布式训练的复杂性,例如数据并行、梯度累积和分布式数据加载。有关更多详细信息,请参阅 分布式训练
  • vLLM:请参阅上一节,了解如何使用 vLLM 加速生成。

以下是在多个节点上使用 GRPO 训练 70B 模型的 SLURM 脚本示例。此脚本在 4 个节点上训练模型,并使用第 5 个节点进行 vLLM 驱动的生成。

#!/bin/bash
#SBATCH --nodes=5
#SBATCH --gres=gpu:8

# Get the list of allocated nodes
NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))

# Assign the first 4 nodes for training and the 5th node for vLLM
TRAIN_NODES="${NODELIST[@]:0:4}"  # Nodes 0, 1, 2, 3 for training
VLLM_NODE="${NODELIST[4]}"  # Node 4 for vLLM

# Run training on the first 4 nodes (Group 1)
srun --nodes=4 --ntasks=4 --nodelist="${NODELIST[@]:0:4}" accelerate launch \
     --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
     --num_processes 32 \
     --num_machines 4 \
     --main_process_ip ${NODELIST[0]} \
     --machine_rank $SLURM_PROCID \
     --rdzv_backend c10d \
     train_grpo.py \
     --server_ip $VLLM_NODE &

# Run vLLM server on the 5th node (Group 2)
srun --nodes=1 --ntasks=1 --nodelist="${NODELIST[4]}" trl vllm-serve --model Qwen/Qwen2.5-72B --tensor_parallel_size 8 &

wait
import argparse

from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vllm_server_host", type=str, default="", help="The server IP")
    args = parser.parse_args()

    # Example dataset from TLDR
    dataset = load_dataset("trl-lib/tldr", split="train")

    # Dummy reward function: count the number of unique characters in the completions
    def reward_num_unique_chars(completions, **kwargs):
        return [len(set(c)) for c in completions]

    training_args = GRPOConfig(
        output_dir="Qwen2.5-72B-GRPO",
        per_device_train_batch_size=4,
        bf16=True,
        gradient_checkpointing=True,
        use_vllm=True,
        vllm_server_host=args.vllm_server_host.replace("ip-", "").replace("-", "."),  # from ip-X-X-X-X to X.X.X.X
    )

    trainer = GRPOTrainer(model="Qwen/Qwen2.5-72B", args=training_args, reward_funcs=reward_num_unique_chars, train_dataset=dataset)
    trainer.train()

if __name__=="__main__":
    main()

使用自定义奖励函数

GRPOTrainer 支持使用自定义奖励函数而不是密集奖励模型。为确保兼容性,您的奖励函数必须满足以下要求

  1. 输入参数:

    • 函数必须接受以下作为关键字参数

      • prompts(包含提示),
      • completions(包含生成的补全),
      • completions_ids(包含标记化的补全),
      • trainer_stateTrainerState):训练器的当前状态。这可用于实现动态奖励函数,例如课程学习,其中奖励根据训练进度进行调整。
      • 数据集可能具有的所有列名(prompt 除外)。例如,如果数据集包含名为 ground_truth 的列,则函数将以 ground_truth 作为关键字参数调用。

      满足此要求的最简单方法是在函数签名中使用 **kwargs

    • 根据数据集格式,输入将有所不同

      • 对于 标准格式promptscompletions 将是字符串列表。
      • 对于 对话格式promptscompletions 将是消息字典列表。
  2. 返回值:函数必须返回一个浮点数列表。每个浮点数代表与单个补全对应的奖励。

示例 1:奖励更长的补全

下面是一个标准格式的奖励函数示例,它奖励更长的补全

def reward_func(completions_ids, **kwargs):
    """Reward function that assigns higher scores to longer completions (in terms of token count)."""
    return [float(len(ids)) for ids in completions_ids]

您可以按如下方式测试它

>>> prompts = ["The sky is", "The sun is"]  # not used in the reward function, but the trainer will pass it
>>> completions = [" blue.", " in the sky."]  # not used in the reward function, but the trainer will pass it
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]]
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[2.0, 4.0]

示例 1.1:奖励更长的补全(基于字符数)

与上一个示例相同,但这次奖励函数基于字符数而不是标记。

def reward_func(completions, **kwargs):
    """Reward function that assigns higher scores to longer completions (in terms of character count)."""
    return [float(len(completion)) for completion in completions]

您可以按如下方式测试它

>>> prompts = ["The sky is", "The sun is"]
>>> completions = [" blue.", " in the sky."]
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]]  # not used in the reward function, but the trainer will pass it
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[6.0, 12.0]

示例 2:奖励具有特定格式的补全

下面是一个奖励函数示例,它检查补全是否具有特定格式。此示例的灵感来自论文 DeepSeek-R1:通过强化学习激励 LLM 的推理能力 中使用的*格式奖励*函数。它专为对话格式设计,其中提示和补全由结构化消息组成。

import re

def format_reward_func(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]

您可以按如下方式测试此功能

>>> prompts = [
...     [{"role": "assistant", "content": "What is the result of (1 + 2) * 4?"}],
...     [{"role": "assistant", "content": "What is the result of (3 + 1) * 2?"}],
... ]
>>> completions = [
...     [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
...     [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
... ]
>>> format_reward_func(prompts=prompts, completions=completions)
[1.0, 0.0]

示例 3:根据参考奖励补全

以下是一个奖励函数示例,它检查补全是否正确。此示例的灵感来自论文 DeepSeek-R1:通过强化学习激励 LLM 的推理能力 中使用的*准确性奖励*函数。此示例专为 标准格式 设计,其中数据集包含名为 ground_truth 的列。

import re

def reward_func(completions, ground_truth, **kwargs):
    # Regular expression to capture content inside \boxed{}
    matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions]
    contents = [match.group(1) if match else "" for match in matches]
    # Reward 1 if the content is the same as the ground truth, 0 otherwise
    return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]

您可以按如下方式测试此功能

>>> prompts = ["Problem: Solve the equation $2x + 3 = 7$. Solution:", "Problem: Solve the equation $3x - 5 = 10$."]
>>> completions = [r" The solution is \boxed{2}.", r" The solution is \boxed{6}."]
>>> ground_truth = ["2", "5"]
>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth)
[1.0, 0.0]

示例 4:多任务奖励函数

以下是在 GRPOTrainer 中使用多个奖励函数的示例。在此示例中,我们定义了两个特定于任务的奖励函数:math_reward_funccoding_reward_funcmath_reward_func 根据正确性奖励数学问题,而 coding_reward_func 根据解决方案是否有效奖励编码问题。

from datasets import Dataset
from trl import GRPOTrainer

# Define a dataset that contains both math and coding problems
dataset = Dataset.from_list(
    [
        {"prompt": "What is 2+2?", "task": "math"},
        {"prompt": "Write a function that returns the sum of two numbers.", "task": "code"},
        {"prompt": "What is 3*4?", "task": "math"},
        {"prompt": "Write a function that returns the product of two numbers.", "task": "code"},
    ]
)

# Math-specific reward function
def math_reward_func(prompts, completions, task, **kwargs):
    rewards = []
    for prompt, completion, t in zip(prompts, completions, task):
        if t == "math":
            # Calculate math-specific reward
            correct = check_math_solution(prompt, completion)
            reward = 1.0 if correct else -1.0
            rewards.append(reward)
        else:
            # Return None for non-math tasks
            rewards.append(None)
    return rewards

# Coding-specific reward function
def coding_reward_func(prompts, completions, task, **kwargs):
    rewards = []
    for prompt, completion, t in zip(prompts, completions, task):
        if t == "coding":
            # Calculate coding-specific reward
            works = test_code_solution(prompt, completion)
            reward = 1.0 if works else -1.0
            rewards.append(reward)
        else:
            # Return None for non-coding tasks
            rewards.append(None)
    return rewards

# Use both task-specific reward functions
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=[math_reward_func, coding_reward_func],
    train_dataset=dataset,
)

trainer.train()

在此示例中,math_reward_funccoding_reward_func 旨在与包含数学和编码问题的混合数据集一起使用。数据集中的 task 列用于确定将哪个奖励函数应用于每个问题。如果数据集中没有与样本相关的奖励函数,则奖励函数将返回 NoneGRPOTrainer 将继续使用有效的函数和任务。这允许 GRPOTrainer 处理具有不同适用性的多个奖励函数。

请注意,GRPOTrainer 将忽略奖励函数返回的 None 奖励,只考虑相关函数返回的奖励。这确保模型在相关任务上进行训练,并忽略没有相关奖励函数的任务。

将奖励函数传递给训练器

要使用您的自定义奖励函数,请按如下方式将其传递给 GRPOTrainer

from trl import GRPOTrainer

trainer = GRPOTrainer(
    reward_funcs=reward_func,
    ...,
)

如果您有多个奖励函数,可以将其作为列表传递

from trl import GRPOTrainer

trainer = GRPOTrainer(
    reward_funcs=[reward_func1, reward_func2],
    ...,
)

奖励将计算为每个函数的奖励之和,如果配置中提供了 reward_weights,则为加权和。

请注意,GRPOTrainer 支持不同类型的多个奖励函数。有关更多详细信息,请参阅参数文档。

视觉语言模型 (VLM) 训练

GRPO 支持在包含文本和图像的多模态数据集上训练视觉语言模型 (VLM)。

支持的模型

已测试的型号:

  • Gemma3 — 例如,google/gemma-3-4b-it
  • LLaVA-NeXT — 例如,llava-hf/llava-v1.6-mistral-7b-hf
  • Qwen2-VL — 例如,Qwen/Qwen2-VL-2B-Instruct
  • Qwen2.5-VL — 例如,Qwen/Qwen2.5-VL-3B-Instruct
  • SmolVLM2 — 例如,HuggingFaceTB/SmolVLM2-2.2B-Instruct
不保证与所有 VLM 兼容。如果您认为某个模型应该得到支持,请随时在 GitHub 上提出问题,或者更好的是,提交包含所需更改的拉取请求。

快速入门

使用 grpo_vlm.py 对 VLM 进行微调。在 lmms-lab/multimodal-open-r1-8k-verified 上训练的示例命令

accelerate launch \
  --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
  examples/scripts/grpo_vlm.py \
  --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
  --output_dir grpo-Qwen2.5-VL-3B-Instruct \
  --learning_rate 1e-5 \
  --gradient_checkpointing \
  --torch_dtype bfloat16 \
  --max_prompt_length 2048 \
  --max_completion_length 1024 \
  --use_vllm \
  --vllm_mode colocate \
  --use_peft \
  --lora_target_modules "q_proj", "v_proj" \
  --log_completions

配置提示

如果图像标记被截断,VLM 训练可能会失败。强烈建议通过将 `max_prompt_length` 设置为 `None` 来禁用截断。
  • 在视觉-语言投影层上使用 LoRA
  • 启用 4 位量化以减少内存使用
  • VLM 是内存密集型的——从较小的批次大小开始
  • 大多数模型与 vLLM 兼容(servercolocate 模式)

数据集格式

每个训练样本应包括

  • prompt:通过处理器聊天模板格式化的文本
  • image:单个图像(PIL 或 NumPy 数组)

训练器通过模型的图像处理器自动处理图像到张量的转换。

GRPOTrainer

class trl.GRPOTrainer

< >

( model: typing.Union[str, transformers.modeling_utils.PreTrainedModel] reward_funcs: typing.Union[str, transformers.modeling_utils.PreTrainedModel, typing.Callable[[list, list], list[float]], list[typing.Union[str, transformers.modeling_utils.PreTrainedModel, typing.Callable[[list, list], list[float]]]]] args: typing.Optional[trl.trainer.grpo_config.GRPOConfig] = None train_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, NoneType] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, dict[str, typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset]], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.processing_utils.ProcessorMixin, NoneType] = None reward_processing_classes: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, list[transformers.tokenization_utils_base.PreTrainedTokenizerBase], NoneType] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) peft_config: typing.Optional[ForwardRef('PeftConfig')] = None )

参数

  • model (Union[str, PreTrainedModel]) — 要训练的模型。可以是以下任一类型:

    • 字符串:Hugging Face 模型库中预训练模型的*模型 ID*,或包含使用 save_pretrained 保存的模型权重的*目录*路径,例如 './my_model_directory/'。模型使用 from_pretrainedargs.model_init_kwargs 中的关键字参数加载。
    • PreTrainedModel 对象。仅支持因果语言模型。
  • reward_funcs (Union[RewardFunc, list[RewardFunc]]) — 用于计算奖励的奖励函数。为了计算奖励,我们将所有奖励函数与提示和补全一起调用并求和。可以是以下任一类型:

    • 单个奖励函数,例如:

      • 字符串:Hugging Face 模型库中预训练模型的*模型 ID*,或包含使用 save_pretrained 保存的模型权重的*目录*路径,例如 './my_model_directory/'。模型使用 from_pretrainednum_labels=1 以及 args.model_init_kwargs 中的关键字参数加载。

      • PreTrainedModel 对象:仅支持序列分类模型。

      • 自定义奖励函数:该函数提供提示和生成的补全,以及数据集中的任何附加列。它应该返回一个奖励列表。当奖励不适用于这些样本时,自定义奖励函数也可以返回 None。这对于多任务训练非常有用,其中不同的奖励函数适用于不同类型的样本。当奖励函数为样本返回 None 时,该奖励函数将从该样本的奖励计算中排除。有关更多详细信息,请参阅 使用自定义奖励函数

        训练器的状态也传递给奖励函数。训练器的状态是 TrainerState 的实例,可以通过访问奖励函数签名的 trainer_state 参数来访问。

    • 奖励函数列表,其中每个项都可以独立地是上述任何类型。允许列表中混合不同类型(例如,字符串模型 ID 和自定义奖励函数)。

  • args (GRPOConfig可选,默认为 None) — 此训练器的配置。如果为 None,则使用默认配置。
  • train_dataset (DatasetIterableDataset) — 用于训练的数据集。它必须包含一个 "prompt" 列。数据集中任何附加列都将被忽略。样本的格式可以是:

    • 标准格式:每个样本包含纯文本。
    • 对话格式:每个样本包含结构化消息(例如,角色和内容)。
  • eval_dataset (Dataset, IterableDatasetdict[str, Union[Dataset, IterableDataset]]) — 用于评估的数据集。它必须满足与 train_dataset 相同的要求。
  • processing_class (PreTrainedTokenizerBaseProcessorMixin可选,默认为 None) — 用于处理数据的处理类。填充侧必须设置为“左”。如果为 None,则从模型的名称中使用 from_pretrained 加载处理类。必须设置填充标记 tokenizer.pad_token。如果处理类未设置填充标记,则 tokenizer.eos_token 将用作默认值。
  • reward_processing_classes (Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]可选,默认为 None) — 与 reward_funcs 中指定的奖励函数对应的处理类。可以是以下任一类型:

    • 单个处理类:当 reward_funcs 只包含一个奖励函数时使用。
    • 处理类列表:必须与 reward_funcs 中奖励函数的顺序和长度匹配。如果设置为 None,或者列表中与 PreTrainedModel 对应的元素为 None,则模型的 tokenizer 会自动使用 from_pretrained 加载。对于 reward_funcs 中是自定义奖励函数(而不是 PreTrainedModel)的元素,reward_processing_classes 中对应的条目将被忽略。
  • callbacks (TrainerCallback 列表,可选,默认为 None) — 自定义训练循环的回调列表。这将添加到 此处 详述的默认回调列表中。

    如果要删除使用的默认回调之一,请使用 remove_callback 方法。

  • optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]可选,默认为 (None, None)) — 包含要使用的优化器和调度器的元组。默认为模型上的 AdamW 实例和由 args 控制的 get_linear_schedule_with_warmup 提供的调度器。
  • peft_config (~peft.PeftConfig可选,默认为 None) — 用于包装模型的 PEFT 配置。如果为 None,则不包装模型。

用于分组相对策略优化(GRPO)方法的训练器。该算法最初是在论文 DeepSeekMath:在开放语言模型中推动数学推理能力的极限 中提出的。

示例

from datasets import load_dataset
from trl import GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")


def reward_func(completions, **kwargs):
    # Dummy reward function that rewards completions with more unique letters.
    return [float(len(set(completion))) for completion in completions]


trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_func,
    train_dataset=dataset,
)

trainer.train()

train

< >

( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None ignore_keys_for_eval: typing.Optional[list[str]] = None **kwargs )

参数

  • resume_from_checkpoint (strbool可选) — 如果是 str,则为先前 Trainer 实例保存的检查点的本地路径。如果为 bool 且等于 True,则加载先前 Trainer 实例在 args.output_dir 中保存的最新检查点。如果存在,训练将从此处加载的模型/优化器/调度器状态恢复。
  • trial (optuna.Trialdict[str, Any]可选) — 用于超参数搜索的试验运行或超参数字典。
  • ignore_keys_for_eval (list[str]可选) — 模型输出(如果是字典)中应在训练期间收集预测以进行评估时忽略的键列表。
  • kwargs (dict[str, Any]可选) — 用于隐藏已弃用参数的附加关键字参数

主训练入口点。

save_model

< >

( output_dir: typing.Optional[str] = None _internal_call: bool = False )

将保存模型,以便您可以使用 `from_pretrained()` 重新加载它。

仅从主进程保存。

push_to_hub

< >

( commit_message: typing.Optional[str] = 'End of training' blocking: bool = True token: typing.Optional[str] = None revision: typing.Optional[str] = None **kwargs )

参数

  • commit_message (str可选,默认为 "End of training") — 推送时提交的消息。
  • blocking (bool可选,默认为 True) — 函数是否应仅在 git push 完成后返回。
  • token (str可选,默认为 None) — 具有写入权限的令牌,用于覆盖 Trainer 的原始 args。
  • revision (str可选) — 要提交的 Git 修订版本。默认为“main”分支的头部。
  • kwargs (dict[str, Any]可选) — 传递给 ~Trainer.create_model_card 的附加关键字参数。

将 `self.model` 和 `self.processing_class` 上传到 🤗 模型中心的 `self.args.hub_model_id` 存储库。

GRPOConfig

class trl.GRPOConfig

< >

( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 1e-06 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: typing.Optional[bool] = None fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = False label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = True model_init_kwargs: typing.Union[dict, str, NoneType] = None disable_dropout: bool = False max_prompt_length: typing.Optional[int] = 512 num_generations: typing.Optional[int] = 8 max_completion_length: typing.Optional[int] = 256 ds3_gather_for_generation: bool = True shuffle_dataset: typing.Optional[bool] = True generation_batch_size: typing.Optional[int] = None steps_per_generation: typing.Optional[int] = None temperature: float = 1.0 top_p: float = 1.0 top_k: typing.Optional[int] = None min_p: typing.Optional[float] = None generation_kwargs: typing.Optional[dict] = None repetition_penalty: float = 1.0 use_transformers_paged: bool = False cache_implementation: typing.Optional[str] = None use_vllm: bool = False vllm_server_base_url: typing.Optional[str] = None vllm_mode: str = 'server' vllm_model_impl: str = 'vllm' vllm_guided_decoding_regex: typing.Optional[str] = None vllm_server_host: str = '0.0.0.0' vllm_server_port: int = 8000 vllm_server_timeout: float = 240.0 vllm_gpu_memory_utilization: float = 0.3 vllm_tensor_parallel_size: int = 1 beta: float = 0.0 num_iterations: int = 1 epsilon: float = 0.2 delta: typing.Optional[float] = None epsilon_high: typing.Optional[float] = None importance_sampling_level: str = 'token' reward_weights: typing.Optional[list[float]] = None scale_rewards: bool = True loss_type: str = 'bnpo' mask_truncated_completions: bool = False sync_ref_model: bool = False ref_model_mixup_alpha: float = 0.6 ref_model_sync_steps: int = 512 top_entropy_quantile: float = 1.0 use_liger_loss: bool = False log_completions: bool = False num_completions_to_print: typing.Optional[int] = None wandb_log_unique_prompts: typing.Optional[bool] = False )

控制模型和参考模型的参数

  • model_init_kwargs (str, dict[str, Any]None, 可选, 默认为 None) — from_pretrained 的关键字参数,当 GRPOTrainermodel 参数以字符串形式提供时使用。
  • disable_dropout (bool, 可选, 默认为 False) — 是否在模型中禁用 dropout。这对于使用参考模型进行训练很有用,因为它可以防止模型为相同输入生成不同的 logprobs。

控制数据预处理的参数

  • remove_unused_columns (bool, 可选, 默认为 False) — 是否仅保留数据集中 "prompt" 列。如果您使用自定义奖励函数,并且该函数除了 "prompts""completions" 之外还需要其他列,则应将其设置为 False
  • max_prompt_length (intNone, 可选, 默认为 512) — 提示的最大长度。如果提示长度超过此值,将从左侧截断。
  • num_generations (intNone, 可选, 默认为 8) — 每个提示的生成样本数。有效批处理大小(num_processes * per_device_batch_size * gradient_accumulation_steps)必须能被此值整除。
  • max_completion_length (intNone, 可选, 默认为 256) — 生成完成的最大长度。
  • ds3_gather_for_generation (bool, 可选, 默认为 True) — 此设置适用于 DeepSpeed ZeRO-3。如果启用,将收集策略模型权重以进行生成,从而提高生成速度。但是,禁用此选项可以训练超出单个 GPU 显存容量的模型,尽管代价是生成速度较慢。禁用此选项与 vLLM 生成不兼容。
  • shuffle_dataset (bool, 可选, 默认为 True) — 是否打乱训练数据集。

控制生成的参数

  • generation_batch_size — (intNone, 可选, 默认为 None): 用于生成的批处理大小。如果为 None,则默认为有效训练批处理大小:per_device_train_batch_size * num_processes * steps_per_generation。换句话说,每个优化步骤处理一个生成批次。与 steps_per_generation 互斥。
  • steps_per_generation — (intNone, 可选, 默认为 None): 每次生成步数。如果为 None,则默认为 gradient_accumulation_steps。与 generation_batch_size 互斥。
  • temperature (float, 默认为 1.0) — 采样的温度。温度越高,完成度越随机。
  • top_p (float, 可选, 默认为 1.0) — 控制要考虑的最高概率标记的累积概率的浮点数。必须在 (0, 1] 范围内。设置为 1.0 以考虑所有标记。
  • top_k (intNone, 可选, 默认为 None) — 保留用于 top-k 过滤的最高概率词汇标记数量。如果为 None,则禁用 top-k 过滤,并考虑所有标记。
  • min_p (floatNone, 可选, 默认为 None) — 最小标记概率,将按最可能标记的概率进行缩放。它必须是 0.01.0 之间的值。典型值在 0.01-0.2 范围内。
  • repetition_penalty (float, 可选, 默认为 1.0) — 根据新标记是否出现在提示和已生成文本中来惩罚新标记的浮点数。值 > 1.0 鼓励模型使用新标记,而值 < 1.0 鼓励模型重复标记。
  • use_transformers_paged (bool, 可选, 默认为 False) — 是否使用 transformers 的分页实现进行生成。如果设置为 True,将使用 transformers 的分页实现进行生成,而不是默认的填充实现。此参数仅在 use_vllm 设置为 False 时有效。
  • cache_implementation (strNone, 可选, 默认为 None) — 当 use_vllm 设置为 False 时,用于更快生成缓存方法的实现。
  • generation_kwargs (dict[str, Any]None, 可选, 默认为 None) — 采样完成时传递给 GenerationConfig(如果使用 transformers)或 SamplingParams(如果使用 vLLM)的附加关键字参数。这可用于进一步自定义生成行为,例如设置 supress_tokensnum_beams 等。如果它包含与其它生成参数(如 min_ptop_p 等)冲突的键,它们将覆盖这些参数。

vLLM 支持的生成加速控制参数

  • use_vllm (bool, 可选, 默认为 False) — 是否使用 vLLM 生成完成。如果设置为 True,训练器将使用 vLLM 进行生成,而不是默认的 model.generate()。需要安装 vllm
  • vllm_mode (str, 可选, 默认为 "server") — 当 use_vllm 设置为 True 时,用于 vLLM 集成的模式。必须是 "server""colocate" 之一。

    • "server": 训练器将生成请求发送到单独的 vLLM 服务器。请确保 TRL vLLM 服务器正在运行(使用 trl vllm-serve 启动)。
    • "colocate": vLLM 将在同一进程中运行并共享训练 GPU。这避免了对单独服务器的需求,但可能会导致与训练的资源争用。
  • vllm_guided_decoding_regex (strNone, 可选, 默认为 None) — vLLM 引导式解码的正则表达式。如果为 None(默认),则禁用引导式解码。

控制 vLLM 服务器的参数(仅当 `vllm_mode` 为 `"server"` 时使用)

  • vllm_server_base_url (strNone, 可选, 默认为 None) — vLLM 服务器的基本 URL(例如,"https://:8000")。如果提供此参数,则 vllm_server_hostvllm_server_port 将被忽略。
  • vllm_server_host (str, 可选, 默认为 "0.0.0.0") — 要连接的 vLLM 服务器主机。如果提供了 vllm_server_base_url,则忽略此参数。
  • vllm_server_port (int, 可选, 默认为 8000) — 要连接的 vLLM 服务器端口。如果提供了 vllm_server_base_url,则忽略此参数。
  • vllm_server_timeout (float, 可选, 默认为 240.0) — 等待 vLLM 服务器启动的总超时时长(秒)。如果在超时后服务器仍未启动,则会引发 ConnectionError

控制共置 vLLM 执行的参数(仅当 `vllm_mode` 为 `"colocate"` 时使用)

  • vllm_gpu_memory_utilization (float, 可选, 默认为 0.3) — 控制 vLLM 的 GPU 内存利用率。此设置仅在 vllm_mode 设置为 "colocate" 时适用。如果您使用 vllm_mode="server",则必须在通过 --vllm_gpu_memory_utilization 标志启动 vLLM 服务器时单独传递此参数。
  • vllm_tensor_parallel_size (int, 可选, 默认为 1) — 控制 vLLM 的张量并行大小。此设置仅在 vllm_mode 设置为 "colocate" 时适用。如果您使用 vllm_mode="server",则必须在通过 --vllm_tensor_parallel_size 标志启动 vLLM 服务器时单独传递此参数。
  • vllm_model_impl (str, 可选, 默认为 "vllm") — 用于 vLLM 的模型实现。必须是 "transformers""vllm" 之一。"transformers":使用 transformers 后端进行模型实现。"vllm":使用 vllm 库进行模型实现。

控制训练的参数

  • beta (float, 可选, 默认为 0.0) — KL 系数。如果为 0.0(默认),则不加载参考模型,从而减少内存使用并提高训练速度。
  • num_iterations (int, 可选, 默认为 1) — 每批次的迭代次数(在算法中表示为 μ)。
  • epsilon (float, 可选, 默认为 0.2) — 用于裁剪的 Epsilon 值。
  • delta — (floatNone, 可选, 默认为 None): 当设置为浮点数时,启用两边 GRPO 损失中的上限裁剪。如果为 None(默认),则使用标准 GRPO 裁剪。建议在启用时大于 1 + ε。此方法在INTELLECT-2 技术报告中引入。
  • epsilon_high (floatNone, 可选, 默认为 None) — 裁剪的上限 epsilon 值。如果未指定,则默认为与参数 epsilon 中指定的下限相同的值。DAPO 论文推荐使用 0.28
  • importance_sampling_level (str, 可选, 默认为 "token") — 控制重要性采样比率是在 "token" 级别还是 "sequence" 级别计算。"token" 保留原始的每令牌对数概率比率(每个令牌一个权重)。"sequence" 对有效令牌的对数概率比率进行平均,为每个序列生成一个比率。GSPO 论文表明,序列级采样通常会带来更稳定的训练和更好的与序列级奖励对齐。
  • reward_weights (list[float]None, 可选, 默认为 None) — 每个奖励函数的权重。必须与奖励函数的数量匹配。如果为 None,所有奖励都以 1.0 的权重平均加权。
  • scale_rewards (bool, 可选, 默认为 True) — 是否通过将奖励除以其标准差来缩放奖励。如果为 True(默认),奖励将按标准差归一化,确保它们具有单位方差。如果为 False,则不应用缩放。Dr. GRPO 论文建议不缩放奖励,因为按标准差缩放会引入问题级难度偏差。
  • loss_type (str, 可选, 默认为 "bnpo") — 指定要使用的损失公式。支持的值有:

    • "grpo":通过对序列长度进行归一化来聚合令牌级损失。不推荐,因为它会产生长度偏差——这种方法倾向于在具有正优势时偏好较短的补全,而在具有负优势时偏好较长的补全。
    • "bnpo":通过对本地批次中活跃令牌的数量进行归一化来聚合令牌级损失。请注意,归一化仅在本地批次上执行,因此结果可能会因本地批次大小而略有不同,尽管有效批次大小是恒定的。当使用 per_device_train_batch_size==1 时,损失等效于 GRPO 损失。
    • "dr_grpo":通过全局常数进行归一化来聚合令牌级损失。此方法在Dr. GRPO 论文中引入,以消除长度偏差。常数的值对应于 max_completion_length
  • mask_truncated_completions (bool, 可选, 默认为 False) — 启用后,截断的补全将从损失计算中排除,防止它们被错误地惩罚并在训练期间引入噪声。根据DAPO 论文,这是训练稳定性的良好实践。
  • sync_ref_model (bool, 可选, 默认为 False) — 是否每 ref_model_sync_steps 步使用 ref_model_mixup_alpha 参数将参考模型与活跃模型同步。此同步源自TR-DPO 论文
  • ref_model_mixup_alpha (float, 可选, 默认为 0.6) — 来自TR-DPO 论文的 α 参数,它控制当前策略和先前参考策略在更新期间的混合。参考策略根据以下公式更新:π_ref = α * π_θ + (1 - α) * π_ref_prev。要使用此参数,必须设置 sync_ref_model=True
  • ref_model_sync_steps (int, 可选, 默认为 512) — 来自TR-DPO 论文的 τ 参数,它决定了当前策略与参考策略同步的频率。要使用此参数,必须设置 sync_ref_model=True
  • top_entropy_quantile (float, 可选, 默认为 1.0) — 来自Beyond the 80/20 Rule的 ρ 参数。只保留每个序列位置上概率分布熵的最高 ρ 分位数令牌在策略损失项中,从而改善结果。范围:[0.0-1.0]。值为 0.0 屏蔽除最高熵令牌外的所有令牌;1.0 保留所有令牌。论文推荐值为 0.2。如果与 mask_truncated_completions=True 一起使用,则只考虑非截断补全中的令牌。
  • use_liger_loss (bool, 可选, 默认为 False) — 是否使用 Liger GRPO 损失。

控制日志记录的参数

  • log_completions (bool, 可选, 默认为 False) — 是否每 logging_steps 步记录一组 (提示, 补全) 对。如果安装了 rich,它会打印该样本。如果启用了 wandb 日志记录,它会将其记录到 wandb
  • num_completions_to_print (intNone, 可选, 默认为 None) — 要使用 rich 打印的补全数量。如果为 None,则记录所有补全。
  • wandb_log_unique_prompts (bool, 可选, 默认为 False) — 是否在 wandb 中记录唯一提示。如果为 True,则只记录唯一提示。如果为 False,则记录所有提示。

GRPOTrainer 的配置类。

此类仅包含 GRPO 训练特有的参数。有关训练参数的完整列表,请参阅 TrainingArguments 文档。请注意,此类的默认值可能与 TrainingArguments 中的默认值不同。

使用 HfArgumentParser,我们可以将此类别转换为可在命令行中指定的 argparse 参数。

< > 在 GitHub 上更新