TRL 文档
Nash-MD 训练器
并获得增强的文档体验
开始使用
Nash-MD 训练器
概述
Nash-MD 是 Rémi Munos、Michal Valko、Daniele Calandriello、Mohammad Gheshlaghi Azar、Mark Rowland、Daniel Guo、Yunhao Tang、Matthieu Geist、Thomas Mésnard 和 Andrea Michi 在论文 Nash Learning from Human Feedback 中提出的。
该论文的摘要如下:
从人类反馈中强化学习 (RLHF) 已成为将大型语言模型 (LLM) 与人类偏好对齐的主要范式。通常,RLHF 涉及从人类反馈中学习奖励模型的初始步骤,人类反馈通常表示为对预训练 LLM 生成的成对文本生成结果的偏好。随后,通过强化学习算法优化 LLM 的策略,使其奖励模型最大化,从而对 LLM 的策略进行微调。然而,当前奖励模型的一个固有局限性在于,它们无法完全代表人类偏好的丰富性及其对采样分布的依赖性。在本研究中,我们介绍了一种使用成对人类反馈微调 LLM 的替代流程。我们的方法包括首先学习一个偏好模型,该模型以给定提示的两个输入为条件,然后寻求一种策略,该策略始终生成优于任何竞争策略生成的响应,从而定义此偏好模型的纳什均衡。我们将这种方法称为从人类反馈中进行纳什学习 (NLHF)。在表格策略表示的背景下,我们提出了一种基于镜像下降原理的新颖算法解决方案 Nash-MD。此算法生成一系列策略,最后一次迭代收敛到正则化纳什均衡。此外,我们探索了策略的参数化表示,并为深度学习架构引入了梯度下降算法。为了证明我们方法的有效性,我们展示了涉及微调用于文本摘要任务的 LLM 的实验结果。我们相信 NLHF 为偏好学习和策略优化提供了一条引人注目的途径,有可能推进 LLM 与人类偏好对齐领域的发展。
此后训练方法由 Kashif Rasul 和 Daniil Tiapkin、Pierre Ménard、Daniele Calandriello 和 Quentin Gallouédec 贡献。
快速入门
此示例演示了如何使用 Nash-MD 方法训练模型。我们使用 Qwen 0.5B 模型 作为基础模型,并使用 PairRMJudge 作为评判器。我们使用来自 UltraFeedback 数据集 的提示。您可以在此处查看数据集中的提示
以下是训练模型的脚本
# train_nash_md.py
from datasets import load_dataset
from trl import NashMDConfig, NashMDTrainer, PairRMJudge
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
training_args = NashMDConfig(output_dir="Qwen2-0.5B-NashMD", logging_steps=10)
trainer = NashMDTrainer(
model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)
trainer.train()
使用以下命令执行脚本
accelerate launch train_nash_md.py
在 8 个 GPU 上分布式训练大约需要 3 小时。
要查看 训练好的模型 的性能,您可以使用 Transformers Chat CLI。
$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-NashMD
<quentin_gallouedec>:
What is the best programming language?
<trl-lib/Qwen2-0.5B-NashMD>:
The best programming language depends on personal preference, the complexity of the project, and the specific requirements of the task. Some programming languages that are often recommended include Python, Java, and JavaScript, and there are many other languages to choose from depending on individual needs.
预期数据集类型
Nash-MD 需要一个仅提示数据集。NashMDTrainer 支持对话式和标准数据集格式。当提供对话式数据集时,训练器将自动将聊天模板应用于数据集。
使用技巧
使用奖励模型
您可以选择使用奖励模型而不是评判器 — 请参阅 Reward Bench 以获取您可以使用的公共模型的排行榜。以下代码示例展示了如何用 trl-lib/Qwen2-0.5B-Reward 模型替换评判器
- from trl import PairRMJudge
+ from transformers import AutoModelForSequenceClassification
- judge = PairRMJudge()
+ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
trainer = NashMDTrainer(
...
- judge=judge,
+ reward_model=reward_model,
)
确保 SFT 模型和奖励模型使用相同的聊天模板和相同的分词器。否则,您可能会发现模型完成在训练期间得分不正确。
鼓励 EOS 令牌生成
我们可能希望模型在给定长度内生成完成。在训练期间,模型将生成完成,直到 NashMDConfig 的 max_new_tokens
参数中指定的最大长度。如果您想惩罚模型在达到最大长度之前未生成 EOS 令牌,您可以使用 NashMDConfig 的 missing_eos_penalty
参数
training_args = NashMDConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
记录完成
为了更好地了解模型在训练期间的行为,您可以定期使用 LogCompletionsCallback 记录样本完成。
trainer = NashMDTrainer(..., eval_dataset=eval_dataset)
completions_callback = LogCompletionsCallback(trainer, num_prompts=8)
trainer.add_callback(completions_callback)
此回调直接将模型生成的完成记录到 Weights & Biases。
示例脚本
我们提供了一个示例脚本,用于使用 Nash-MD 方法训练模型。该脚本位于 examples/scripts/nash_md.py
中
要使用 Qwen2.5 0.5B 模型 在 UltraFeedback 数据集 上测试在线 DPO 脚本,请运行以下命令
python examples/scripts/nash_md.py \ --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ --judge pair_rm \ --dataset_name trl-lib/ultrafeedback-prompt \ --learning_rate 5.0e-7 \ --logging_steps 25 \ --output_dir Qwen2.5-0.5B-NashMD-PairRM \ --warmup_ratio 0.1 \ --push_to_hub
记录的指标
记录的指标如下:
loss/kl
:模型和参考数据之间的平均 KL 散度。objective/entropy
:模型和参考数据的平均熵。loss/score
:平均强化分数损失。rewards/chosen
:模型完成的平均分数(根据奖励模型)。rewards/rejected
:混合完成的平均分数(根据奖励模型)。rewards/probabilities
:模型完成选择与混合完成的平均概率(根据奖励模型或评判器)。rewards/accuracies
:Nash-MD 隐式奖励模型的准确率。rewards/margins
:选择和混合完成之间平均奖励边际(根据奖励模型)。logps/chosen
:选择完成的平均对数概率。logps/rejected
:参考完成的平均对数概率。val/model_contain_eos_token
:模型输出包含 eos 令牌的次数。val/ref_contain_eos_token
:混合输出包含 eos 令牌的次数。beta
:控制代表偏离参考模型的损失项权重的参数。通常是固定的,但可以通过将列表传递给 NashMDConfig 使其动态化。mixture_coef
:模型和参考模型的 Logit 混合系数。通常是固定的,但可以通过将列表传递给 NashMDConfig 使其动态化。
NashMDTrainer
class trl.NashMDTrainer
< source >( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None ref_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None reward_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, NoneType] = None judge: typing.Optional[trl.trainer.judges.BasePairwiseJudge] = None args: typing.Optional[trl.trainer.nash_md_config.NashMDConfig] = None data_collator: typing.Optional[typing.Callable] = None train_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, NoneType] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None peft_config: typing.Optional[dict] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], dict]] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None )
参数
- model (
transformers.PreTrainedModel
) — 要训练的模型,最好是AutoModelForCausalLM
。 - ref_model (
PreTrainedModelWrapper
) — Hugging Face transformers 模型,带有因果语言建模头。用于隐式奖励计算和损失。如果未提供参考模型,训练器将创建一个与要优化的模型具有相同架构的参考模型。 - reward_model (
transformers.PreTrainedModel
) — 用于对补全结果进行评分的奖励模型,最好是AutoModelForSequenceClassification
。 - judge (
BasePairwiseJudge
) — 用于模型补全结果的成对比较的评判器。 - args (
NashMDConfig
) — 用于训练的 NashMD 配置参数。 - data_collator (
transformers.DataCollator
) — 用于训练的数据收集器。如果未指定,将使用默认数据收集器 (`DPODataCollatorWithPadding`),它将根据成对序列的数据集,将序列填充到批次中最长序列的长度。 - train_dataset (
datasets.Dataset
) — 用于训练的数据集。 - eval_dataset (
datasets.Dataset
) — 用于评估的数据集。 - processing_class (
PreTrainedTokenizerBase
或BaseImageProcessor
或FeatureExtractionMixin
或ProcessorMixin
, optional) — 用于处理数据的处理类。如果提供,将用于自动处理模型的输入,并与模型一起保存,以便更容易地重新运行中断的训练或重用微调模型。 - peft_config (
dict
) — 用于训练的 peft 配置。 - compute_metrics (
Callable[[EvalPrediction], dict]
, optional) — 用于计算指标的函数。必须接受一个EvalPrediction
并返回一个从字符串到指标值的字典。 - callbacks (
list[transformers.TrainerCallback]
) — 用于训练的回调。 - optimizers (
tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]
) — 用于训练的优化器和调度器。 - preprocess_logits_for_metrics (
Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
) — 用于在计算指标之前预处理 logits 的函数。
将 NashMDTrainer 初始化为 OnlineDPOConfig 的子类。
NashMDConfig
class trl.NashMDConfig
< 源代码 >( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 5e-07 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict, str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: typing.Optional[str] = 'passive' log_level_replica: typing.Optional[str] = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict, str, NoneType] = None tp_size: typing.Optional[int] = 0 fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict, str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: typing.Optional[int] = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = False reward_model_path: typing.Optional[str] = None judge: typing.Optional[str] = None max_new_tokens: int = 64 max_length: int = 512 temperature: float = 0.9 missing_eos_penalty: typing.Optional[float] = None beta: list = <factory> loss_type: str = 'sigmoid' dataset_num_proc: typing.Optional[int] = None disable_dropout: bool = True use_vllm: bool = False ds3_gather_for_generation: bool = True mixture_coef: list = <factory> )
为 NashMDTrainer 提供的配置类。
作为 OnlineDPOConfig 的子类,我们可以使用其所有参数并添加以下参数