TRL 文档

PRM Trainer

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

PRM Trainer

PRM Trainer 是一个实验性的 API,随时可能更改。

概述

过程监督奖励模型 (PRM) 在 Solving math word problems with process- and outcome-based feedback 中被提出,作者是 Jonathan Uesato、Nate Kushman、Ramana Kumar、Francis Song、Noah Siegel、Lisa Wang、Antonia Creswell、Geoffrey Irving 和 Irina Higgins。

该论文的摘要如下:

最近的工作表明,要求语言模型生成推理步骤可以提高许多推理任务的性能。当超越提示时,这就提出了我们应该如何监督此类模型的问题:基于结果的方法,即监督最终结果;还是基于过程的方法,即监督推理过程本身?这些方法之间的差异自然可能不仅体现在最终答案错误上,还体现在推理错误上,后者可能难以检测,并且在教育等许多实际领域中存在问题。我们对在自然语言任务 GSM8K 上训练的基于过程和基于结果的方法进行了首次全面比较。我们发现,纯粹基于结果的监督以更少的标签监督产生相似的最终答案错误率。然而,对于正确的推理步骤,我们发现有必要使用基于过程的监督或来自模拟基于过程的反馈的学习奖励模型的监督。总的来说,我们改进了之前的最佳结果,最终答案错误率从 16.8% 降至 12.7%,最终答案正确的解决方案中的推理错误率从 14.0% 降至 3.4%。

此后训练方法由 Gaetan LopezLewis TunstallQuentin GallouédecAgustín Piqueres 贡献。

快速开始

此示例演示了如何使用 PRM 方法训练模型。我们使用 Qwen 0.5B 模型 作为基础模型。我们使用来自 Math Shepherd 数据集 的逐步监督数据。您可以在此处查看数据集中的数据

以下是训练模型的脚本

# train_prm.py
from datasets import load_dataset
from trl import PRMConfig, PRMTrainer
from transformers import AutoModelForTokenClassification, AutoTokenizer

model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
train_dataset = load_dataset("trl-lib/math_shepherd", split="train[:10%]")

training_args = PRMConfig(output_dir="Qwen2-0.5B-Reward-Math-Sheperd", logging_steps=10)
trainer = PRMTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()

使用以下命令执行脚本

accelerate launch train_prm.py

在 8 个 GPU 上分布式训练,大约需要 1 小时。

要查看 训练好的模型 的性能,您可以使用以下脚本。

from datasets import load_dataset
from transformers import pipeline

pipe = pipeline("token-classification", model="trl-lib/Qwen2-0.5B-Reward-Math-Sheperd")
dataset = load_dataset("trl-lib/math_shepherd")
example = {
    "prompt": "Musa is the class teacher of a class of 45 students. He wants to split them into three groups by age. If a third of the class is under 11 years, and two-fifths are above 11 but under 13, how many students will be in the third group (13 years and above)?",
    "completions": [
        "Step 1: A third of the class is under 11 years because 11 - 1/3 = <<11-1/3=7>>7.",
        "Step 2: Two-fifths of the class are above 11 but under 13 because 2/5 * 11 = <<2/5*11=8>>8.",
        "Step 3: There are 45 students, so the third group will have 45 - 7 - 8 = <<45-7-8=20>>20 students. The answer is: 20",
    ],
    "labels": [True, False, False],
}


separator = "\n"  # It's important to use the same separator as the one used during training

for idx in range(1, len(example["completions"]) + 1):
    steps = example["completions"][0:idx]
    text = separator.join((example["prompt"], *steps)) + separator  # Add a separator between the prompt and each steps
    pred_entity = pipe(text)[-1]["entity"]
    pred = {"LABEL_0": False, "LABEL_1": True}[pred_entity]
    label = example["labels"][idx - 1]
    print(f"Step {idx}\tPredicted: {pred} \tLabel: {label}")
Step 1  Predicted: True         Label: True
Step 2  Predicted: False        Label: False
Step 3  Predicted: False        Label: False

成功了!

预期数据集类型

PRM 需要逐步监督。数据集应包含以下列:promptcompletionslabels,其中 completions 包含推理步骤列表,labels 包含布尔值或浮点数列表,指示每个步骤的正确性。

PRMTrainer 仅支持标准数据集格式。

示例脚本

我们提供了一个示例脚本,用于使用 PRM 方法训练模型。该脚本位于 examples/scripts/prm.py

要将 PRM 脚本与 Qwen2 0.5B 模型Math Shepherd 数据集 一起使用,请运行以下命令

accelerate launch examples/scripts/prm.py \
    --model_name_or_path Qwen/Qwen2-0.5B \
    --dataset_name trl-lib/math_shepherd \
    --num_train_epochs 1 \
    --logging_steps 25 \
    --output_dir Qwen2-0.5B-Reward-Math-Sheperd

PRMTrainer

class trl.PRMTrainer

< >

( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, NoneType] = None args: typing.Optional[trl.trainer.prm_config.PRMConfig] = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None train_dataset: typing.Optional[datasets.arrow_dataset.Dataset] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None model_init: typing.Optional[typing.Callable[[], transformers.modeling_utils.PreTrainedModel]] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], dict]] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None peft_config: typing.Optional[dict] = None )

参数

  • model (transformers.PreTrainedModel) — 要训练的模型,最好是 AutoModelForTokenClassification
  • args (PRMConfig) — 用于训练的参数。
  • data_collator (transformers.DataCollator) — 用于训练的数据整理器。如果未指定,将使用默认数据整理器 (DataCollatorForTokenClassification),它将成对序列数据集中的序列填充到批次中序列的最大长度。
  • train_dataset (datasets.Dataset) — 用于训练的数据集。
  • eval_dataset (datasets.Dataset) — 用于评估的数据集。
  • processing_class (PreTrainedTokenizerBaseBaseImageProcessorFeatureExtractionMixinProcessorMixin, 可选) — 用于处理数据的处理类。如果提供,将用于自动处理模型的输入,并与模型一起保存,以便更轻松地重新运行中断的训练或重用微调后的模型。
  • model_init (Callable[[], transformers.PreTrainedModel]) — 用于训练的模型初始化器。如果未指定,将使用默认模型初始化器。
  • compute_metrics (Callable[[transformers.EvalPrediction], dict], 可选,默认为 compute_accuracy) — 用于评估的指标。如果未指定任何指标,将使用默认指标 (compute_accuracy)。
  • callbacks (list[transformers.TrainerCallback]) — 用于训练的回调。
  • optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]) — 用于训练的优化器和调度器。
  • preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) — 用于在计算指标之前预处理 logits 的函数。
  • peft_config (dict, 默认为 None) — 用于训练的 PEFT 配置。如果您传递 PEFT 配置,则模型将包装在 PEFT 模型中。

初始化 PRMTrainer。

create_model_card

< >

( model_name: typing.Optional[str] = None dataset_name: typing.Optional[str] = None tags: typing.Union[str, list[str], NoneType] = None )

参数

  • model_name (strNone, 可选, 默认为 None) — 模型的名称。
  • dataset_name (strNone, 可选, 默认为 None) — 用于训练的数据集名称。
  • tags (str, list[str]None, 可选, 默认为 None) — 与模型卡关联的标签。

使用 Trainer 可用的信息创建模型卡的草稿。

tokenize_row

< >

( features tokenizer step_separator max_length max_prompt_length max_completion_length train_on_last_step_only is_eval ) dict[str, list[int]]

参数

  • features (dict[str, str]) — 数据集的行,应包含键 "prompt", "completions", 和 "labels"
  • tokenizer (PreTrainedTokenizerBase) — 用于处理数据的分词器。
  • step_separator (str) — 完成步骤之间的分隔符。
  • max_length (intNone) — 序列(prompt + completion)的最大长度。如果为 None,则序列不会被截断。
  • max_prompt_length (intNone) — prompt 的最大长度。如果为 None,则 prompt 不会被截断。
  • max_completion_length (intNone) — completion 序列的最大长度。如果为 None,则 completion 序列不会被截断。
  • train_on_last_step_only (bool) — 是否仅在最后一步进行训练。如果为 True,则除了 completion 的最后一个 token 外,所有 token 的标签均为 -100
  • is_eval (bool) — 该函数是否用于对来自训练数据集或评估数据集的样本进行分词。仅当 train_on_last_step_only 设置为 True 时使用。

返回值

dict[str, list[int]]

分词后的序列,键为 "input_ids"“labels”

对数据集的行进行分词。

示例

>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
>>> features = {"prompt": "Which number is larger, 9.8 or 9.11?",
...             "completions": ["11 is greater than 8.",
...                             "Hence, 9.11 > 9.8."],
...             "labels": [True, False]}
>>> PRMTrainer.tokenize_row(features, tokenizer, "\n", max_completion_length=None, train_on_last_step_only=False, is_eval=False)
{'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}

PRMConfig

class trl.PRMConfig

< >

( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 1e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict, str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: typing.Optional[str] = 'passive' log_level_replica: typing.Optional[str] = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict, str, NoneType] = None tp_size: typing.Optional[int] = 0 fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict, str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: typing.Optional[int] = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = False max_length: typing.Optional[int] = 1024 max_prompt_length: typing.Optional[int] = 512 max_completion_length: typing.Optional[int] = None disable_dropout: bool = True step_separator: str = '\n' train_on_last_step_only: bool = False dataset_num_proc: typing.Optional[int] = None )

参数

  • learning_rate (float, optional, defaults to 1e-5) — AdamW 优化器的初始学习率。默认值替换 TrainingArguments 的默认值。
  • max_length (intNone, 可选, 默认为 1024) — 用于截断的序列(prompt + completion)的最大长度。
  • max_prompt_length (intNone, 可选, 默认为 512) — 用于截断的 prompt 的最大长度。
  • max_completion_length (intNone, 可选, 默认为 None) — 用于截断的 completion 的最大长度。completion 是步骤的串联。
  • disable_dropout (bool, 可选, 默认为 True) — 是否禁用模型中的 dropout。
  • step_separator (str, 可选, 默认为 "\n") — 用于分隔推理过程的每个步骤的分隔符。
  • train_on_last_step_only (bool, 可选, 默认为 False) — 是否仅在最后一步进行训练。
  • dataset_num_proc (int, 可选, 默认为 None) — 用于处理数据集的进程数。

用于 PRMTrainer 的配置类。

使用 HfArgumentParser 我们可以将此类转换为 argparse 参数,这些参数可以在命令行中指定。

< > 在 GitHub 上更新