TRL 文档

KTO Trainer

Hugging Face's logo
加入 Hugging Face 社区

并获得增强文档体验

来入门

KTO Trainer

TRL 支持卡尼曼-特沃斯基优化(KTO)训练器,用于调整语言模型与二进制反馈数据(例如,赞成/反对),如 Kawin Ethayarajh、Winnie Xu、Niklas Muennighoff、Dan Jurafsky 和 Douwe Kiela 在论文中所述。如需了解完整示例,请查看examples/scripts/kto.py

根据基本模型的优劣,在 KTO 之前可能需要或不需要执行 SFT。这不同于标准的 RLHF 和 DPO,它们始终需要 SFT。

预期数据集格式

KTO 训练器对数据集有非常特殊的要求,因为它不需要成对的偏好。由于该模型将经过训练来直接优化由提示、模型完成和一个指示完成是“好”还是“坏”的标签组成的示例,我们希望数据集包含以下列:

  • 提示
  • 完成
  • 标签

例如

kto_dataset_dict = {
    "prompt": [
        "Hey, hello",
        "How are you",
        "What is your name?",
        "What is your name?",
        "Which is the best programming language?",
        "Which is the best programming language?",
        "Which is the best programming language?",
    ],
    "completion": [
        "hi nice to meet you",
        "leave me alone",
        "I don't have a name",
        "My name is Mary",
        "Python",
        "C++",
        "Java",
    ],
    "label": [
        True,
        False,
        False,
        True,
        True,
        False,
        False,
    ],
}

其中提示包含文本输入,完成包含相应的响应,标签包含指示生成完成是否需要(True)或不需要(False)的相应标记。一个提示可以有多个响应,反映在字典值数组中的条目重复。数据集必须至少包含一个需要的完成和一个不需要的完成。

预期模型格式

KTO 训练器期望一个 AutoModelForCausalLM 模型,而 PPO 期望将 AutoModelForCausalLMWithValueHead 用于该值函数。

使用 KTOTrainer

如需详细示例,请参阅 examples/scripts/kto.py 脚本。从一般层面讲,我们需要使用我们希望训练的 model 和一个参考 ref_model(我们用它来计算 preferred 和 rejected 响应的隐式回报)初始化 KTOTrainer

beta 指隐式回报的超参数,数据集包含以上列出的 3 个条目。请注意,modelref_model 需要有相同的架构(即仅解码器,或编码器-解码器)。

desirable_weightundesirable_weight 指对 desirable/positive 和 undesirable/negative 示例的损失赋予的权重。默认情况下,两者都为 1。但是,如果您有一个比另一个更多,则应该提升较少类型的权重,使得(desirable_weight * 积极示例数)与(undesirable_weight * 消极示例数)的比率在 1:1 到 4:3 之间。

training_args = KTOConfig(
    beta=0.1,
    desirable_weight=1.0,
    undesirable_weight=1.0,
)

kto_trainer = KTOTrainer(
    model,
    model_ref,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)

完成此操作后,即可调用

kto_trainer.train()

损失函数

给定指示一个完成对于提示是 desirable 还是 undesirable 的二进制信号数据,我们可以优化一个与卡尼曼-特沃斯基前景理论的原则一致的隐式回报函数,例如参照依存、损失厌恶和效用递减。

BCO 作者训练了一个二进制分类器,其 logit 表示奖励,以便分类器将 {提示、选择的完成} 对映射到 1,将 {提示、拒绝的完成} 对映射到 0。可以使用 loss_type="bco" 参数将 KTOTrainer 切换到此损失。

对于专家混合模型:启用辅助损失

如果负载在大致均匀地分布在专家之间时,MOE 是最有效的。
为了确保我们在偏好调节期间以类似方式训练 MOE,最好将负载均衡器的辅助损失添加到最终损失中。

可以通过在模型配置(例如 MixtralConfig)中设置 output_router_logits=True 来启用此选项。
要调整辅助损失对总损失的贡献程度,请使用超参数 router_aux_loss_coef=...(默认值:0.001)。

KTOTrainer

trl.KTOTrainer

< >

( model: Union = None ref_model: Union = None args: KTOConfig = None train_dataset: Optional = None eval_dataset: Union = None tokenizer: Optional = None data_collator: Optional = None model_init: Optional = None callbacks: Optional = None optimizers: Tuple = (None, None) preprocess_logits_for_metrics: Optional = None peft_config: Optional = None compute_metrics: Optional = None model_adapter_name: Optional = None ref_adapter_name: Optional = None embedding_func: Optional = None embedding_tokenizer: Optional = None )

参数

  • model (transformers.PreTrainedModel) — 要训练的模型,最好是 AutoModelForSequenceClassification
  • ref_model (PreTrainedModelWrapper) — 带有随意语言建模头的 Hugging Face Transformer 模型。用于隐式奖励计算和损失。如果没有提供引用模型,则训练器将创建与要优化的模型具有相同架构的引用模型。
  • args (KTOConfig) — 训练中使用的参数。
  • train_dataset (datasets.Dataset) — 训练中使用的数据集。
  • tokenizer (transformers.PreTrainedTokenizerBase) - 训练所用的 tokenizer。如果您想使用默认数据整理程序,此参数是必需的。
  • data_collator (transformers.DataCollator可选,默认为 None) - 训练时使用的数据整理程序。如果指定了 None,将使用默认的数据整理程序 (DPODataCollatorWithPadding),该程序会将序列填充至批处理中序列的最大长度,条件是给定了一组配对序列。
  • model_init (Callable[[], transformers.PreTrainedModel]) — 训练时使用的模型初始化器。如果未指定,将使用默认模型初始化器。
  • callbacks (List[transformers.TrainerCallback]) — 训练时使用的回调函数。
  • preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) — 测量值计算之前用于预处理logits的函数。
  • peft_config (Dict,默认为 None) — 用于训练的 PEFT 配置。如果您传递 PEFT 配置,该模型将封装在 PEFT 模型中。
  • compute_metrics (Callable[[EvalPrediction], Dict]可选) — 用于计算指标的函数。必须采用 EvalPrediction 并返回字典字符串以获取指标值。
  • model_adapter_name (str,默认为 None) — 使用 LoRA 搭配多个适配器时,训练目标 PEFT 适配器的名称。
  • ref_adapter_name (str, 默认为 None) — 使用带有多个适配器的 LoRA 时,参考 PEFT 适配器的名称。

初始化 KTOTrainer。

bco_loss

< >

( policy_chosen_logps: FloatTensor policy_rejected_logps: FloatTensor reference_chosen_logps: FloatTensor reference_rejected_logps: FloatTensor chosen_embeddings: Optional rejected_embeddings: Optional ) 四元组张量

返回值

四个张量的元组

(损失,选择的奖励,拒绝的奖励,KL)。损失张量包含每批样本的 KTO 损失。chosen_rewards 和 rejected_rewards 张量分别包含所选和拒绝响应的奖励。delta 值包含所有隐式奖励的移动平均值。

计算一批策略和参考模型的对数概率的 BCO 损失。

compute_reference_log_probs

< >

( padded_batch: Dict )

计算 KTO 特定数据集的一个单一填充批次的参考模型的对数概率。

evaluation_loop

< >

( dataloader: DataLoader description: str prediction_loss_only: Optional = None ignore_keys: Optional = None metric_key_prefix: str = 'eval' )

覆盖内置评估循环以存储每批次的指标。 预测/评估循环,由 Trainer.evaluate()Trainer.predict() 共享。

有无标签都可以使用。

get_batch_logps

< >

( logits: FloatTensor labels: LongTensor average_log_prob: bool = False label_pad_token_id: int = -100 is_encoder_decoder: bool = False )

获取指定 logits 下指定标签的对数概率。

get_batch_loss_metrics

< >

( model batch: Dict )

为训练或测试计算给定一批输入的 KTO 损失和其他指标。

get_batch_samples

< >

( model batch: Dict )

根据给定的输入批次从模型和参考模型生成样本。

get_eval_dataloader

< >

( eval_dataset: 可选 = 无 )

参数

  • eval_dataset (torch.utils.data.Dataset, 可选) — 如果提供,将覆盖 self.eval_dataset。如果它是一个 Dataset,则 model.forward() 方法不接受的列将自动删除。它必须实现 __len__

返回评估 ~torch.utils.data.DataLoader

transformers.src.transformers.trainer.get_eval_dataloader 的子类,用于预计算 ref_log_probs

get_train_dataloader

< >

( )

返回训练 ~torch.utils.data.DataLoader

transformers.src.transformers.trainer.get_train_dataloader 的子类,用于预计算 ref_log_probs

kto_loss

< >

( policy_chosen_logps: FloatTensor policy_rejected_logps: FloatTensor policy_KL_logps: FloatTensor reference_chosen_logps: FloatTensor reference_rejected_logps: FloatTensor reference_KL_logps: FloatTensor ) 一个四元组

返回值

四个张量的元组

(损失、选择的奖励、拒绝的奖励、KL)。损失张量包含批次中每个示例的 KTO 损失。选择和拒绝的奖励张量分别包含选择和拒绝的回复的奖励。KL 张量包含策略和参考模型之间的分离的 KL 差异估计。

计算一批策略和参考模型对数概率的 KTO 损失。

日志

< >

( logs: Dict )

参数

  • logs (Dict[str, float]) — 待记录的值。

记录 logs 到观察训练的不同对象上,包括存储的指标。

KTOConfig

class trl.KTOConfig

< >

( output_dir: str overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: Union = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: Optional = None per_gpu_eval_batch_size: Optional = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: Optional = None eval_delay: Optional = 0 torch_empty_cache_steps: Optional = None learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: Union = 'linear' lr_scheduler_kwargs: Union = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: Optional = 'passive' log_level_replica: Optional = 'warning' log_on_each_node: bool = True logging_dir: Optional = None logging_strategy: Union = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: Union = 'steps' save_steps: float = 500 save_total_limit: Optional = None save_safetensors: Optional = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: Optional = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: Optional = None local_rank: int = -1 ddp_backend: Optional = None tpu_num_cores: Optional = None tpu_metrics_debug: bool = False debug: Union = '' dataloader_drop_last: bool = False eval_steps: Optional = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: Optional = None past_index: int = -1 run_name: Optional = None disable_tqdm: Optional = None remove_unused_columns: Optional = True label_names: Optional = None load_best_model_at_end: Optional = False metric_for_best_model: Optional = None greater_is_better: Optional = None ignore_data_skip: bool = False fsdp: Union = '' fsdp_min_num_params: int = 0 fsdp_config: Union = None fsdp_transformer_layer_cls_to_wrap: Optional = None accelerator_config: Union = None deepspeed: Union = None label_smoothing_factor: float = 0.0 optim: Union = 'adamw_torch' optim_args: Optional = None adafactor: bool = False group_by_length: bool = False length_column_name: Optional = 'length' report_to: Union = None ddp_find_unused_parameters: Optional = None ddp_bucket_cap_mb: Optional = None ddp_broadcast_buffers: Optional = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: Optional = None hub_model_id: Optional = None hub_strategy: Union = 'every_save' hub_token: Optional = None hub_private_repo: bool = False hub_always_push: bool = False gradient_checkpointing: bool = False gradient_checkpointing_kwargs: Union = None include_inputs_for_metrics: bool = False eval_do_concat_batches: bool = True fp16_backend: str = 'auto' evaluation_strategy: Union = None push_to_hub_model_id: Optional = None push_to_hub_organization: Optional = None push_to_hub_token: Optional = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: Optional = None ray_scope: Optional = 'last' ddp_timeout: Optional = 1800 torch_compile: bool = False torch_compile_backend: Optional = None torch_compile_mode: Optional = None dispatch_batches: Optional = None split_batches: Optional = None include_tokens_per_second: Optional = False include_num_input_tokens_seen: Optional = False neftune_noise_alpha: Optional = None optim_target_modules: Union = None batch_eval_metrics: bool = False eval_on_start: bool = False eval_use_gather_object: Optional = False max_length: Optional = None max_prompt_length: Optional = None max_completion_length: Optional = None beta: float = 0.1 desirable_weight: Optional = 1.0 undesirable_weight: Optional = 1.0 label_pad_token_id: int = -100 padding_value: int = None truncation_mode: str = 'keep_end' generate_during_eval: bool = False is_encoder_decoder: Optional = None precompute_ref_log_probs: bool = False model_init_kwargs: Optional = None ref_model_init_kwargs: Optional = None dataset_num_proc: Optional = None loss_type: Literal = 'kto' prompt_sample_size: int = 1024 min_density_ratio: float = 0.5 max_density_ratio: float = 10.0 )

参数

  • max_prompt_length (int, 可选,默认为 None) — 提示的最大长度。如果你想使用默认数据收集器,则需要此参数。
  • max_completion_length (int, 可选,默认为 None) — 目标的最大长度。如果你想使用默认数据收集器并且你的模型是一个编码器/解码器,则需要此参数。
  • Beta 版 (float,默认为 0.1) — KTO 损失中的 beta 因子。更高的 beta 意味着与初始策略的差异更小。
  • desirable_weight (float可选,默认为 1.0) — desirable 损失会通过此因子加权,以应对 desirable 和 undesirable 对数量不均衡的情况。
  • label_pad_token_id (int,默认为 -100) — 标签填充标记 id。如果您希望使用默认数据整理程序,则需要此参数。
  • padding_value (int,默认为 0) — 如果与标记器 `pad_token_id` 不同,则需要填充值。
  • truncation_mode (str,默认为 keep_end) — 使用的截断模式,keep_endkeep_start。如果你想使用默认数据整理器,此参数是必需的。
  • generate_during_eval (bool,默认为 False) — 是否在评估步骤中抽取和记录生成。
  • is_encoder_decoder (Optional[bool], optional, 默认为 None) — 如果未提供模型,我们需要知道 model_init 是否返回编码器-解码器。
  • precompute_ref_log_probs (bool, 默认为 False) — 训练和评估数据集的预计算参考模型日志概率标志。如果您想在没有参考模型的情况下进行训练并减少所需的总 GPU 内存,这将很有用。 model_init_kwargs — (Optional[Dict], 可选): 从字符串实例化模型时传递的可选 kwargs 词典。 ref_model_init_kwargs — (Optional[Dict], 可选): 从字符串实例化 ref 模型时传递的可选 kwargs 词典。 dataset_num_proc — (Optional[int], 可选, 默认为 None): 用于处理数据集的进程数。 loss_type — (Literal["kto", "bco"], 可选): 要使用的损失类型。 "kto" 是默认 KTO 损失,"bco" 损失来自 BCO 论文。 prompt_sample_size — (int, 默认为 1024): 馈送到密度比分类器的提示数。 min_density_ratio — (float, 默认为 0.5): 密度比的最小值。估计的密度比被限定为该值。 max_density_ratio — (float, 默认为 10.0): 密度比的最大值。估计的密度比被限定为该值。

KTOConfig 收集与 KTOTrainer 类相关的训练参数。

使用 HfArgumentParser,我们可以将此类转变为 argparse 参数,可在命令行中指定这些参数。

< > 更新,已在 GitHub 上(Update