TRL 文档

Nash-MD 训练器

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

Nash-MD 训练器

概述

Nash-MD 由 Rémi Munos、Michal Valko、Daniele Calandriello、Mohammad Gheshlaghi Azar、Mark Rowland、Daniel Guo、Yunhao Tang、Matthieu Geist、Thomas Mésnard 和 Andrea Michi 在论文 《Nash Learning from Human Feedback》 中提出。

论文摘要如下:

从人类反馈中进行强化学习 (RLHF) 已成为使大型语言模型 (LLM) 与人类偏好对齐的主要范式。通常,RLHF 的第一步是从人类反馈中学习一个奖励模型,这些反馈通常表示为对预训练 LLM 生成的一对文本之间的偏好。随后,通过强化学习算法优化 LLM 的策略,使其最大化奖励模型。然而,当前奖励模型的一个固有局限性是它们无法完全表示人类偏好的丰富性,并且依赖于采样分布。在本研究中,我们介绍了一种使用成对人类反馈对 LLM 进行微调的替代流程。我们的方法包括首先学习一个基于给定提示的两个输入的偏好模型,然后寻求一种策略,该策略能持续生成优于任何竞争策略的响应,从而定义该偏好模型的纳什均衡。我们将此方法称为从人类反馈中进行纳什学习 (NLHF)。在表格化策略表示的背景下,我们提出了一种新颖的算法解决方案 Nash-MD,它基于镜像下降的原理。该算法产生一系列策略,最终迭代收敛到正则化的纳什均衡。此外,我们探索了策略的参数化表示,并为深度学习架构引入了梯度下降算法。为了展示我们方法的有效性,我们展示了在文本摘要任务中微调 LLM 的实验结果。我们相信 NLHF 为偏好学习和策略优化提供了一条有吸引力的途径,并有潜力推动使 LLM 与人类偏好对齐领域的发展。

此后训练方法由 Kashif RasulDaniil TiapkinPierre Ménard、Daniele Calandriello 和 Quentin Gallouédec 贡献。

快速开始

这个例子演示了如何使用 Nash-MD 方法训练一个模型。我们使用 Qwen 0.5B 模型 作为基础模型,并使用 PairRMJudge 作为评判器。我们使用来自 UltraFeedback 数据集 的提示。你可以在此处查看数据集中的提示。

以下是训练模型的脚本

# train_nash_md.py
from datasets import load_dataset
from trl import NashMDConfig, NashMDTrainer, PairRMJudge
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

training_args = NashMDConfig(output_dir="Qwen2-0.5B-NashMD")
trainer = NashMDTrainer(
    model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)
trainer.train()

使用以下命令执行脚本

accelerate launch train_nash_md.py

在 8 个 GPU 上进行分布式训练,大约需要 3 小时。

要查看 训练后模型 的性能,您可以使用 Transformers Chat CLI

$ transformers chat trl-lib/Qwen2-0.5B-NashMD
<quentin_gallouedec>:
What is the best programming language?

<trl-lib/Qwen2-0.5B-NashMD>:
The best programming language depends on personal preference, the complexity of the project, and the specific requirements of the task. Some programming languages that are often recommended include Python, Java, and JavaScript, and there are many other languages to choose from depending on individual needs.

预期数据集类型

Nash-MD 需要一个仅含提示的数据集NashMDTrainer 支持对话式标准两种数据集格式。当提供对话式数据集时,训练器会自动将聊天模板应用于数据集。

使用技巧

使用奖励模型

除了评判器,您也可以选择使用奖励模型——请参阅 Reward Bench 获取可用的公开模型排行榜。以下代码示例展示了如何用 trl-lib/Qwen2-0.5B-Reward 模型替换评判器。

- from trl import PairRMJudge
+ from transformers import AutoModelForSequenceClassification

- judge = PairRMJudge()
+ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)

  trainer = NashMDTrainer(
      ...
-     judge=judge,
+     reward_model=reward_model,
  )

请确保 SFT 模型和奖励模型使用相同的聊天模板和分词器。否则,您可能会发现在训练过程中模型补全的评分不正确。

鼓励生成 EOS 令牌

我们可能希望模型在给定的长度内生成补全。在训练期间,模型将生成补全,其长度最多为 NashMDConfigmax_new_tokens 参数指定的最大长度。如果您想惩罚模型在达到最大长度之前未生成 EOS 令牌,可以使用 NashMDConfigmissing_eos_penalty 参数。

training_args = NashMDConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)

记录补全

为了更好地理解模型在训练过程中的行为,您可以使用 LogCompletionsCallback 定期记录样本补全。

trainer = NashMDTrainer(..., eval_dataset=eval_dataset)
completions_callback = LogCompletionsCallback(trainer, num_prompts=8)
trainer.add_callback(completions_callback)

此回调函数直接将模型生成的补全记录到 Weights & Biases。

Logged Completions

示例脚本

我们提供了一个使用 Nash-MD 方法训练模型的示例脚本。该脚本位于 examples/scripts/nash_md.py

要使用 Qwen2.5 0.5B 模型UltraFeedback 数据集 上测试在线 DPO 脚本,请运行以下命令:

python examples/scripts/nash_md.py \
    --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --judge pair_rm \
    --dataset_name trl-lib/ultrafeedback-prompt \
    --learning_rate 5.0e-7 \
    --output_dir Qwen2.5-0.5B-NashMD-PairRM \
    --warmup_ratio 0.1 \
    --push_to_hub

记录的指标

记录的指标如下

  • loss/kl:模型与参考数据之间的平均 KL 散度。
  • objective/entropy:模型与参考数据的平均熵。
  • loss/score:平均强化分数损失。
  • rewards/chosen:模型补全的平均分数(根据奖励模型)。
  • rewards/rejected:混合补全的平均分数(根据奖励模型)。
  • rewards/probabilities:模型补全被选中与混合补全的平均概率(根据奖励模型或评判器)。
  • rewards/accuracies:Nash-MD 隐式奖励模型的准确率。
  • rewards/margins:被选中和混合补全之间的平均奖励边际(根据奖励模型)。
  • logps/chosen:被选中补全的平均对数概率。
  • logps/rejected:参考补全的平均对数概率。
  • val/model_contain_eos_token:模型输出包含 eos 令牌的次数。
  • val/ref_contain_eos_token:混合输出包含 eos 令牌的次数。
  • beta:控制与参考模型偏差的损失项权重的参数。通常是固定的,但可以通过向 NashMDConfig 传递一个列表来使其动态化。
  • mixture_coef:模型和参考模型的 Logit 混合系数。通常是固定的,但可以通过向 NashMDConfig 传递一个列表来使其动态化。

NashMDTrainer

class trl.NashMDTrainer

< >

( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None ref_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None reward_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, NoneType] = None judge: typing.Optional[trl.trainer.judges.BasePairwiseJudge] = None args: typing.Optional[trl.trainer.nash_md_config.NashMDConfig] = None data_collator: typing.Optional[typing.Callable] = None train_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, NoneType] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None peft_config: typing.Optional[dict] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], dict]] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None )

参数

  • model (transformers.PreTrainedModel) — 用于训练的模型,最好是 AutoModelForCausalLM
  • ref_model (PreTrainedModelWrapper) — 带有因果语言模型头的 Hugging Face transformer 模型。用于隐式奖励计算和损失。如果未提供参考模型,训练器将创建一个与待优化模型具有相同架构的参考模型。
  • reward_model (transformers.PreTrainedModel) — 用于对补全进行评分的奖励模型,最好是 AutoModelForSequenceClassification
  • judge (BasePairwiseJudge) — 用于对模型补全进行成对比较的评判器。
  • args (NashMDConfig) — 用于训练的 NashMD 配置参数。
  • data_collator (transformers.DataCollator) — 用于训练的数据整理器。如果未指定,将使用默认的数据整理器 (DPODataCollatorWithPadding),该整理器会根据批次中序列的最大长度,对成对序列的数据集进行填充。
  • train_dataset (datasets.Dataset) — 用于训练的数据集。
  • eval_dataset (datasets.Dataset) — 用于评估的数据集。
  • processing_class (PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixinProcessorMixin, 可选, 默认为 None) — 用于处理数据的处理类。如果提供,将用于自动处理模型的输入,并与模型一起保存,以便更容易地重新运行中断的训练或重用微调后的模型。
  • peft_config (dict) — 用于训练的 peft 配置。
  • compute_metrics (Callable[[EvalPrediction], dict], 可选) — 用于计算指标的函数。必须接受一个 EvalPrediction 并返回一个从字符串到指标值的字典。
  • callbacks (list[transformers.TrainerCallback]) — 用于训练的回调函数。
  • optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]) — 用于训练的优化器和调度器。
  • preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) — 在计算指标前用于预处理 logits 的函数。

初始化 NashMDTrainer 作为 OnlineDPOConfig 的子类。

train

< >

( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None ignore_keys_for_eval: typing.Optional[list[str]] = None **kwargs )

参数

  • resume_from_checkpoint (strbool, 可选) — 如果是 str,则为之前 Trainer 实例保存的检查点的本地路径。如果是 bool 且为 True,则加载之前 Trainer 实例在 *args.output_dir* 中保存的最后一个检查点。如果存在,训练将从此处加载的模型/优化器/调度器状态继续。
  • trial (optuna.Trialdict[str, Any], 可选) — 试验运行或用于超参数搜索的超参数字典。
  • ignore_keys_for_eval (list[str], 可选) — 在训练期间收集评估预测时,应忽略的模型输出中的键列表(如果输出是字典)。
  • kwargs (dict[str, Any], 可选) — 用于隐藏已弃用参数的附加关键字参数

主训练入口点。

save_model

< >

( output_dir: typing.Optional[str] = None _internal_call: bool = False )

将保存模型,以便您可以使用 `from_pretrained()` 重新加载它。

仅从主进程保存。

push_to_hub

< >

( commit_message: typing.Optional[str] = 'End of training' blocking: bool = True token: typing.Optional[str] = None revision: typing.Optional[str] = None **kwargs )

参数

  • commit_message (str, 可选, 默认为 "End of training") — 推送时提交的消息。
  • blocking (bool, 可选, 默认为 True) — 该函数是否应该仅在 git push 完成后返回。
  • token (str, 可选, 默认为 None) — 具有写入权限的令牌,用于覆盖 Trainer 的原始参数。
  • revision (str, 可选) — 要提交的 git 修订版本。默认为 "main" 分支的头部。
  • kwargs (dict[str, Any], 可选) — 传递给 ~Trainer.create_model_card 的附加关键字参数。

将 `self.model` 和 `self.processing_class` 上传到 🤗 模型中心的 `self.args.hub_model_id` 存储库。

NashMDConfig

class trl.NashMDConfig

< >

( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 5e-07 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: typing.Optional[bool] = None fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = True reward_model_path: typing.Optional[str] = None judge: typing.Optional[str] = None max_new_tokens: int = 64 max_length: int = 512 temperature: float = 0.9 missing_eos_penalty: typing.Optional[float] = None beta: list = <factory> loss_type: str = 'sigmoid' dataset_num_proc: typing.Optional[int] = None disable_dropout: bool = True use_vllm: bool = False vllm_model_impl: str = 'vllm' gpu_memory_utilization: typing.Optional[float] = 0.55 ds3_gather_for_generation: bool = True model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None mixture_coef: list = <factory> )

参数

  • mixture_coef (float or list[float], optional, defaults to 0.5) — 用于模型和参考模型的 Logit 混合系数。如果提供一个浮点数列表,则会为每个新周期选择混合系数,最后一个系数将用于剩余的周期。

NashMDTrainer 的配置类。

OnlineDPOConfig 的子类,我们可以使用其所有参数,并添加以下内容

< > 在 GitHub 上更新