TRL 文档
XPO 训练器
并获得增强的文档体验
开始使用
XPO 训练器
概览
探索性偏好优化 (Exploratory Preference Optimization, XPO) 在论文 《探索性偏好优化:利用隐式 Q*-近似实现样本高效的 RLHF》 (Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF) 中被提出,作者为 Tengyang Xie, Dylan J. Foster, Akshay Krishnamurthy, Corby Rosset, Ahmed Awadallah, and Alexander Rakhlin。它是一种简单的在线偏好调优方法,基于 DPO 损失和奖励模型 (RM)。XPO 通过增加探索奖励来增强 DPO 目标,使该方法能够探索初始模型和人类反馈数据支持范围之外的内容。
论文摘要如下:
基于人类反馈的强化学习 (RLHF) 已成为语言模型对齐的核心工具。我们考虑 RLHF 中的在线探索,它利用与人类或 AI 反馈的交互式访问,通过有意鼓励模型产生多样化、信息量最大的响应。通过允许 RLHF 自信地偏离预训练模型,在线探索为实现新颖、可能超越人类的能力提供了可能性,但由于直接调整现有强化学习技术存在计算和统计瓶颈,其作为语言模型训练范式的全部潜力尚未实现。我们提出了一种新的 RLHF 在线探索算法,即探索性偏好优化 (XPO),它简单实用——只需对(在线)直接偏好优化 (DPO; Rafailov et al., 2023) 做一行代码的更改——却拥有已知最强的可证明保证和有前景的经验性能。XPO 通过一种新颖且有原则的探索奖励来增强 DPO 目标,使算法能够探索初始模型和人类反馈数据支持范围之外的内容。理论上,我们证明了在自然的探索条件下,XPO 具有可证明的样本效率,并能收敛到接近最优的语言模型策略,无论初始模型是否具有良好的覆盖范围。我们的分析基于 DPO 隐式执行一种 Q*-近似(或贝尔曼误差最小化)的观察,通过 KL 正则化马尔可夫决策过程的视角,将语言建模和理论强化学习中先前分离的技术以一种偶然的方式结合起来。在经验上,我们发现 XPO 在初步评估中比非探索性 DPO 变体具有更高的样本效率。
此训练后方法由 Kashif Rasul, Quentin Gallouédec 和 Lewis Tunstall 贡献。
快速入门
此示例演示了如何使用 XPO 方法训练模型。我们使用 Qwen 0.5B 模型作为基础模型,并使用 PairRMJudge 作为评判器。我们使用 UltraFeedback 数据集 中的提示。你可以在此处查看数据集中的提示。
以下是训练模型的脚本
# train_xpo.py
from datasets import load_dataset
from trl import PairRMJudge, XPOConfig, XPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
training_args = XPOConfig(output_dir="Qwen2-0.5B-XPO")
trainer = XPOTrainer(
model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)
trainer.train()
使用以下命令执行脚本
accelerate launch train_xpo.py
在 8 个 GPU 上分布式训练,大约需要 1 小时。
要查看训练后模型的表现,您可以使用 Transformers Chat CLI。
$ transformers chat trl-lib/Qwen2-0.5B-XPO
<quentin_gallouedec>:
What is the best programming language?
<trl-lib/Qwen2-0.5B-XPO>:
The best programming language depends on individual preferences and familiarity with coding concepts. Some popular languages include Python, Java, C++, and JavaScript.
预期数据集类型
XPO 需要一个仅包含提示的数据集。XPOTrainer 支持对话格式和标准格式的数据集。当提供对话格式的数据集时,训练器会自动将聊天模板应用于数据集。
使用技巧
使用奖励模型
除了评判器,您也可以选择使用奖励模型——请参阅 Reward Bench 获取可用的公开模型排行榜。以下代码示例展示了如何用 trl-lib/Qwen2-0.5B-Reward 模型替换评判器。
- from trl import PairRMJudge
+ from transformers import AutoModelForSequenceClassification
- judge = PairRMJudge()
+ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
trainer = XPOTrainer(
...
- judge=judge,
+ reward_model=reward_model,
)
请确保 SFT 模型和奖励模型使用*相同*的聊天模板和分词器。否则,您可能会发现在训练过程中模型的补全得分不正确。
鼓励生成 EOS 标记
当使用奖励模型时,我们可能希望模型在给定的长度内生成补全。训练期间,模型将生成补全,其长度最多为 XPOConfig 的 `max_new_tokens` 参数指定的最大长度。如果您想惩罚模型在达到最大长度前未生成 EOS 标记,可以使用 XPOConfig 的 `missing_eos_penalty` 参数。
training_args = XPOConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
日志记录补全
为了更好地理解模型在训练过程中的行为,您可以使用 LogCompletionsCallback 定期记录样本补全。
trainer = XPOTrainer(..., eval_dataset=eval_dataset)
completions_callback = LogCompletionsCallback(trainer, num_prompts=8)
trainer.add_callback(completions_callback)
此回调函数直接将模型生成的补全记录到 Weights & Biases。
示例脚本
我们提供了一个示例脚本,用于使用 XPO 方法训练模型。该脚本位于 examples/scripts/xpo.py
要在 UltraFeedback 数据集上测试 Qwen2.5 0.5B 模型的 XPO 脚本,请运行以下命令。
python examples/scripts/xpo.py \ --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ --judge pair_rm \ --dataset_name trl-lib/ultrafeedback-prompt \ --learning_rate 5.0e-7 \ --output_dir Qwen2.5-0.5B-XPO-PairRM \ --warmup_ratio 0.1 \ --push_to_hub
记录的指标
记录的指标如下:
- `loss/xpo`:完整损失中 xpo 部分的平均值。
- `loss/dpo`:完整损失中 dpo 部分的平均值。
objective/kl
:模型与参考数据之间的平均 KL 散度。objective/entropy
:模型和参考数据的平均熵。objective/model_scores
:模型补全的平均得分(根据奖励模型)。objective/ref_scores
:参考补全的平均得分(根据奖励模型)。- `objective/scores_margin`:被选中和被拒绝补全之间的平均得分差(根据外部奖励模型)。
rewards/chosen
:被选中补全的平均奖励(根据 XPO 的 DPO 隐式奖励模型)。rewards/rejected
:被拒绝补全的平均奖励(根据 XPO 的 DPO 隐式奖励模型)。rewards/accuracies
:XPO 的隐式奖励模型的准确率。rewards/margins
:被选中和被拒绝补全之间的平均奖励差(根据在线 DPO 的隐式奖励模型)。logps/chosen
:被选中补全的平均对数概率。logps/rejected
:被拒绝补全的平均对数概率。- `val/model_contain_eos_token`:模型输出包含 eos 标记的次数。
- `val/ref_contain_eos_token`:参考输出包含 eos 标记的次数。
- `alpha`:XPO 损失项的权重。通常是固定的,但可以通过向 XPOConfig 传递一个列表来使其动态化。
- `beta`:控制表示与参考模型偏差的损失项权重的参数。通常是固定的,但可以通过向 XPOConfig 传递一个列表来使其动态化。
XPOTrainer
class trl.XPOTrainer
< 来源 >( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None ref_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None reward_model: typing.Optional[torch.nn.modules.module.Module] = None judge: typing.Optional[trl.trainer.judges.BasePairwiseJudge] = None args: typing.Optional[trl.trainer.xpo_config.XPOConfig] = None data_collator: typing.Optional[typing.Callable] = None train_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, NoneType] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None peft_config: typing.Optional[dict] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], dict]] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None )
参数
- model (
transformers.PreTrainedModel
) — 用于训练的模型,最好是AutoModelForCausalLM
。 - ref_model (
PreTrainedModelWrapper
) — 带有因果语言模型头的 Hugging Face Transformer 模型。用于隐式奖励计算和损失。如果没有提供参考模型,训练器将创建一个与待优化模型结构相同的参考模型。 - reward_model (
transformers.PreTrainedModel
) — 用于对补全进行评分的奖励模型,最好是 `AutoModelForSequenceClassification`。 - judge (
BasePairwiseJudge
) — 用于对模型补全进行成对比较的评判器。 - args (
XPOConfig
) — 用于训练的 XPO 配置参数。 - data_collator (`transformers.DataCollator`) — 用于训练的数据整理器。如果未指定,将使用默认的数据整理器 (`DPODataCollatorWithPadding`),它会在给定成对序列的数据集时将序列填充到批次中序列的最大长度。
- train_dataset (`datasets.Dataset`) — 用于训练的数据集。
- eval_dataset (`datasets.Dataset`) — 用于评估的数据集。
- processing_class (`PreTrainedTokenizerBase`、`BaseImageProcessor`、`FeatureExtractionMixin` 或 `ProcessorMixin`,*可选*,默认为 `None`) — 用于处理数据的处理类。如果提供,将用于自动处理模型的输入,并与模型一起保存,以便更容易地重新运行中断的训练或重用微调后的模型。
- peft_config (`dict`) — 用于训练的 peft 配置。
- compute_metrics (`Callable[[EvalPrediction], dict]`,*可选*) — 用于计算指标的函数。必须接受一个 `EvalPrediction` 并返回一个从字符串到指标值的字典。
- callbacks (`list[transformers.TrainerCallback]`) — 用于训练的回调函数。
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`) — 用于训练的优化器和调度器。
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`) — 在计算指标之前用于预处理 logits 的函数。
将 XPOTrainer 初始化为 OnlineDPOConfig 的子类。
train
< 来源 >( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None ignore_keys_for_eval: typing.Optional[list[str]] = None **kwargs )
参数
- resume_from_checkpoint (`str` 或 `bool`, *可选*) — 如果是 `str`,则为之前 `Trainer` 实例保存的检查点的本地路径。如果是 `bool` 且等于 `True`,则加载 `args.output_dir` 中由之前 `Trainer` 实例保存的最后一个检查点。如果存在,训练将从此处加载的模型/优化器/调度器状态恢复。
- trial (`optuna.Trial` 或 `dict[str, Any]`,*可选*) — 用于超参数搜索的试验运行或超参数字典。
- ignore_keys_for_eval (`list[str]`, *可选*) — 一个包含模型输出中(如果输出是字典)在训练期间收集评估预测时应忽略的键的列表。
- kwargs (`dict[str, Any]`, *可选*) — 用于隐藏已弃用参数的附加关键字参数。
主训练入口点。
将保存模型,以便您可以使用 `from_pretrained()` 重新加载它。
仅从主进程保存。
push_to_hub
< 来源 >( commit_message: typing.Optional[str] = 'End of training' blocking: bool = True token: typing.Optional[str] = None revision: typing.Optional[str] = None **kwargs )
参数
- commit_message (`str`, *可选*,默认为 `"End of training"`) — 推送时的提交信息。
- blocking (`bool`, *可选*,默认为 `True`) — 函数是否应该在 `git push` 完成后才返回。
- token (`str`, *可选*,默认为 `None`) — 具有写权限的令牌,用于覆盖 Trainer 的原始参数。
- revision (`str`, *可选*) — 要提交的 git 修订版本。默认为“main”分支的头部。
- kwargs (
dict[str, Any]
, 可选) — 传递给~Trainer.create_model_card
的额外关键字参数。
将 `self.model` 和 `self.processing_class` 上传到 🤗 模型中心的 `self.args.hub_model_id` 存储库。
XPOConfig
class trl.XPOConfig
< 源 >( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 5e-07 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: typing.Optional[bool] = None fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = True reward_model_path: typing.Optional[str] = None judge: typing.Optional[str] = None max_new_tokens: int = 64 max_length: int = 512 temperature: float = 0.9 missing_eos_penalty: typing.Optional[float] = None beta: list = <factory> loss_type: str = 'sigmoid' dataset_num_proc: typing.Optional[int] = None disable_dropout: bool = True use_vllm: bool = False vllm_model_impl: str = 'vllm' gpu_memory_utilization: typing.Optional[float] = 0.55 ds3_gather_for_generation: bool = True model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None alpha: list = <factory> )
用于 XPOTrainer 的配置类。
OnlineDPOConfig 的子类,我们可以使用它的所有参数并添加以下内容