TRL 文档
CPO 训练器
并获得增强的文档体验
开始使用
CPO 训练器
概述
对比偏好优化(Contrastive Preference Optimization,CPO)由 Haoran Xu、Amr Sharaf、Yunmo Chen、Weiting Tan、Lingfeng Shen、Benjamin Van Durme、Kenton Murray 和 Young Jin Kim 在论文 Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation 中提出。概括来说,CPO 训练模型在机器翻译(MT)任务中避免生成足够好但并非完美的翻译。然而,CPO 是 DPO 损失的一般近似,可以应用于其他领域,例如聊天。
CPO 旨在缓解 SFT 的两个基本缺点。首先,SFT 最小化预测输出与黄金标准参考之间差异的方法,本质上将模型性能限制在训练数据的质量水平。其次,SFT 缺乏一种机制来防止模型拒绝翻译中的错误。CPO 目标是从 DPO 目标派生而来的。
快速入门
本示例演示了如何使用 CPO 方法训练模型。我们使用 Qwen 0.5B 模型 作为基础模型。我们使用来自 UltraFeedback 数据集 的偏好数据。你可以在此处查看数据集中的数据。
以下是训练模型的脚本
# train_cpo.py
from datasets import load_dataset
from trl import CPOConfig, CPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = CPOConfig(output_dir="Qwen2-0.5B-CPO")
trainer = CPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
使用以下命令执行脚本
accelerate launch train_cpo.py
预期的数据集类型
CPO 需要一个 偏好数据集。CPOTrainer 支持 对话式 和 标准 两种数据集格式。当提供对话式数据集时,训练器会自动将聊天模板应用于数据集。
示例脚本
我们提供了一个示例脚本来使用 CPO 方法训练模型。该脚本位于 examples/scripts/cpo.py
要使用 Qwen2 0.5B 模型 在 UltraFeedback 数据集 上测试 CPO 脚本,请运行以下命令
accelerate launch examples/scripts/cpo.py \ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ --dataset_name trl-lib/ultrafeedback_binarized \ --num_train_epochs 1 \ --output_dir Qwen2-0.5B-CPO
记录的指标
在训练和评估期间,我们记录以下奖励指标
rewards/chosen
:策略模型对选定响应的平均对数概率,按 beta 缩放rewards/rejected
:策略模型对拒绝响应的平均对数概率,按 beta 缩放rewards/accuracies
:所选奖励大于相应被拒奖励的频率的平均值rewards/margins
:选中奖励和相应拒绝奖励之间的平均差异nll_loss
:策略模型对选定响应的平均负对数似然损失
CPO 变体
简单偏好优化 (SimPO)
SimPO 方法也在 CPOTrainer 中实现。SimPO 是一种替代损失函数,它增加了一个奖励边际,允许长度归一化,并且不使用 BC 正则化。要使用此损失,我们可以在 CPOConfig 中将 loss_type="simpo"
和 cpo_alpha=0.0
来轻松使用 SimPO。
CPO-SimPO
我们还提供 CPO 和 SimPO 的组合使用,可以实现更稳定的训练和更好的性能。详情请见 CPO-SimPO GitHub。要使用此方法,只需在 CPOConfig 中设置 loss_type="simpo"
并设置一个非零的 cpo_alpha
来启用 SimPO。
损失函数
CPO 算法支持多种损失函数。可以通过在 CPOConfig 中使用 loss_type
参数来设置损失函数。支持以下损失函数
loss_type= | 描述 |
---|---|
"sigmoid" (默认) | 给定偏好数据,我们可以根据 Bradley-Terry 模型拟合一个二元分类器,事实上 DPO 的作者提出通过 logsigmoid 对归一化似然使用 sigmoid 损失来拟合逻辑回归。 |
"hinge" | RSO 的作者提议使用来自 SLiC 论文的铰链损失(hinge loss)来处理归一化似然。在这种情况下,beta 是边际的倒数。 |
"ipo" | IPO 的作者对 DPO 算法提供了更深入的理论理解,并指出了过拟合问题,提出了一种替代损失。在这种情况下,beta 是选定完成与拒绝完成对的对数似然比之间差距的倒数,因此 beta 越小,这个差距就越大。根据论文,损失是对完成的对数似然进行平均(与 DPO 不同,DPO 仅进行求和)。 |
对于混合专家模型:启用辅助损失
如果负载在专家之间大致均匀分布,MOE(专家混合模型)效率最高。
为了确保在偏好调整期间类似地训练 MOE,将负载均衡器的辅助损失添加到最终损失中是有益的。
通过在模型配置(例如 MixtralConfig
)中设置 output_router_logits=True
来启用此选项。
要调整辅助损失对总损失的贡献程度,请在模型配置中使用超参数 router_aux_loss_coef=...
(默认值:0.001
)。
CPOTrainer
class trl.CPOTrainer
< 源代码 >( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, str, NoneType] = None args: typing.Optional[trl.trainer.cpo_config.CPOConfig] = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None train_dataset: typing.Optional[datasets.arrow_dataset.Dataset] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None model_init: typing.Optional[typing.Callable[[], transformers.modeling_utils.PreTrainedModel]] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None peft_config: typing.Optional[dict] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalLoopOutput], dict]] = None )
参数
- model (
transformers.PreTrainedModel
) — 用于训练的模型,最好是AutoModelForSequenceClassification
。 - args (
CPOConfig
) — 用于训练的 CPO 配置参数。 - data_collator (
transformers.DataCollator
) — 用于训练的数据整理器。如果未指定,将使用默认的数据整理器 (DPODataCollatorWithPadding
),它会根据批次中序列的最大长度对序列进行填充,适用于成对序列的数据集。 - train_dataset (
datasets.Dataset
) — 用于训练的数据集。 - eval_dataset (
datasets.Dataset
) — 用于评估的数据集。 - processing_class (
PreTrainedTokenizerBase
,BaseImageProcessor
,FeatureExtractionMixin
orProcessorMixin
, 可选, 默认为None
) — 用于处理数据的处理类。如果提供,将用于自动处理模型的输入,并与模型一起保存,以便更容易地重新运行中断的训练或重用微调后的模型。 - model_init (
Callable[[], transformers.PreTrainedModel]
) — 用于训练的模型初始化器。如果未指定,将使用默认的模型初始化器。 - callbacks (
list[transformers.TrainerCallback]
) — 用于训练的回调函数。 - optimizers (
tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]
) — 用于训练的优化器和学习率调度器。 - preprocess_logits_for_metrics (
Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
) — 用于在计算指标前预处理 logits 的函数。 - peft_config (
dict
, 默认为None
) — 用于训练的 PEFT 配置。如果传递 PEFT 配置,模型将被包装为 PEFT 模型。 - compute_metrics (
Callable[[EvalPrediction], dict]
, 可选) — 用于计算指标的函数。必须接受一个EvalPrediction
并返回一个从字符串到指标值的字典。
初始化 CPOTrainer。
train
< 源代码 >( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None ignore_keys_for_eval: typing.Optional[list[str]] = None **kwargs )
参数
- resume_from_checkpoint (
str
orbool
, 可选) — 如果是str
,则是之前Trainer
实例保存的检查点的本地路径。如果是bool
且等于True
,则加载之前Trainer
实例在 *args.output_dir* 中保存的最后一个检查点。如果存在,训练将从此处加载的模型/优化器/调度器状态恢复。 - trial (
optuna.Trial
ordict[str, Any]
, 可选) — 用于超参数搜索的试验运行或超参数字典。 - ignore_keys_for_eval (
list[str]
, 可选) — 在训练期间收集评估预测时,模型输出中(如果输出是字典)应忽略的键列表。 - kwargs (
dict[str, Any]
, 可选) — 用于隐藏已弃用参数的额外关键字参数。
主训练入口点。
将保存模型,以便您可以使用 `from_pretrained()` 重新加载它。
仅从主进程保存。
push_to_hub
< 源代码 >( commit_message: typing.Optional[str] = 'End of training' blocking: bool = True token: typing.Optional[str] = None revision: typing.Optional[str] = None **kwargs )
参数
将 `self.model` 和 `self.processing_class` 上传到 🤗 模型中心的 `self.args.hub_model_id` 存储库。
CPOConfig
class trl.CPOConfig
< 源代码 >( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 1e-06 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: typing.Optional[bool] = None fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = True max_length: typing.Optional[int] = 1024 max_prompt_length: typing.Optional[int] = 512 max_completion_length: typing.Optional[int] = None beta: float = 0.1 label_smoothing: float = 0.0 loss_type: str = 'sigmoid' disable_dropout: bool = True cpo_alpha: float = 1.0 simpo_gamma: float = 0.5 label_pad_token_id: int = -100 padding_value: typing.Optional[int] = None truncation_mode: str = 'keep_end' generate_during_eval: bool = False is_encoder_decoder: typing.Optional[bool] = None model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None dataset_num_proc: typing.Optional[int] = None )
参数
- max_length (
int
orNone
, optional, defaults to1024
) — 批次中序列(提示+补全)的最大长度。如果要使用默认的数据整理器,则此参数是必需的。 - max_prompt_length (
int
orNone
, optional, defaults to512
) — 提示的最大长度。如果要使用默认的数据整理器,则此参数是必需的。 - max_completion_length (
int
orNone
, optional, defaults toNone
) — 补全的最大长度。如果要使用默认的数据整理器且模型是编码器-解码器模型,则此参数是必需的。 - beta (
float
, optional, defaults to0.1
) — 控制与参考模型偏差的参数。较高的 β 意味着与参考模型的偏差较小。对于 IPO 损失 (loss_type="ipo"
),β 是论文中表示为 τ 的正则化参数。 - label_smoothing (
float
, optional, defaults to0.0
) — 标签平滑因子。如果要使用默认的数据整理器,则此参数是必需的。 - loss_type (
str
, optional, defaults to"sigmoid"
) — 要使用的损失类型。可能的值有: - disable_dropout (
bool
, optional, defaults toTrue
) — 是否在模型中禁用 dropout。 - cpo_alpha (
float
, optional, defaults to1.0
) — CPO 训练中 BC 正则化器的权重。 - simpo_gamma (
float
, optional, defaults to0.5
) — SimPO 损失的目标奖励边际,仅在loss_type="simpo"
时使用。 - label_pad_token_id (
int
, optional, defaults to-100
) — 标签填充标记 ID。如果要使用默认的数据整理器,则此参数是必需的。 - padding_value (
int
orNone
, optional, defaults toNone
) — 要使用的填充值。如果为None
,则使用分词器的填充值。 - truncation_mode (
str
,optional, defaults to"keep_end"
) — 当提示过长时使用的截断模式。可能的值为"keep_end"
或"keep_start"
。如果要使用默认的数据整理器,则此参数是必需的。 - generate_during_eval (
bool
, optional, defaults toFalse
) — 如果为True
,则在评估期间生成模型的补全并记录到 W&B 或 Comet。 - is_encoder_decoder (
bool
orNone
, optional, defaults toNone
) — 当使用model_init
参数(可调用函数)来实例化模型而不是model
参数时,你需要指定该可调用函数返回的模型是否为编码器-解码器模型。 - model_init_kwargs (
dict[str, Any]
orNone
, optional, defaults toNone
) — 从字符串实例化模型时,传递给AutoModelForCausalLM.from_pretrained
的关键字参数。 - dataset_num_proc (
int
orNone
, optional, defaults toNone
) — 用于处理数据集的进程数。
CPOTrainer 的配置类。
此类仅包含 CPO 训练特有的参数。有关训练参数的完整列表,请参阅 TrainingArguments
文档。请注意,此类中的默认值可能与 TrainingArguments
中的默认值不同。
使用 HfArgumentParser
,我们可以将此类别转换为可在命令行上指定的 argparse 参数。