TRL 文档

KTO 训练器

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

KTO 训练器

概述

Kahneman-Tversky 优化 (KTO) 由 KTO: Model Alignment as Prospect Theoretic Optimization 论文引入,作者包括 Kawin EthayarajhWinnie XuNiklas Muennighoff、Dan Jurafsky 和 Douwe Kiela

该论文的摘要如下:

Kahneman & Tversky 的前景理论告诉我们,人类以一种有偏差但定义明确的方式感知随机变量;例如,众所周知,人类是厌恶损失的。我们表明,使 LLM 与人类反馈对齐的目标隐含地包含了许多这些偏差——这些目标(例如,DPO)相对于交叉熵最小化的成功,部分可以归因于它们是人类感知的损失函数(HALO)。然而,这些方法归因于人类的效用函数仍然与前景理论文献中的效用函数不同。使用 Kahneman-Tversky 人类效用模型,我们提出了一种 HALO,它直接最大化生成的效用,而不是像当前方法那样最大化偏好的对数似然。我们将这种方法称为 Kahneman-Tversky 优化 (KTO),它在 1B 到 30B 的规模下,性能与基于偏好的方法相当或超过。至关重要的是,KTO 不需要偏好——只需要一个二元信号,表明对于给定的输入,输出是理想的还是不理想的。这使得它在现实世界中更容易使用,因为在现实世界中,偏好数据稀缺且昂贵。

官方代码可以在 ContextualAI/HALOs 中找到。

这种后训练方法由 Kashif RasulYounes BelkadaLewis Tunstall 和 Pablo Vicente 贡献。

快速开始

此示例演示了如何使用 KTO 方法训练模型。我们使用 Qwen 0.5B 模型 作为基础模型。我们使用来自 KTO Mix 14k 的偏好数据。您可以在此处的数据集中查看数据

以下是训练模型的脚本

# train_kto.py
from datasets import load_dataset
from trl import KTOConfig, KTOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/kto-mix-14k", split="train")

training_args = KTOConfig(output_dir="Qwen2-0.5B-KTO", logging_steps=10)
trainer = KTOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()

使用以下命令执行脚本

accelerate launch train_kto.py

在 8 个 H100 GPU 上分布式执行,训练大约需要 30 分钟。您可以通过查看奖励图来验证训练进度。奖励边际的增加趋势表明模型正在改进,并且随着时间的推移生成更好的响应。

要查看 训练好的模型 的性能,您可以使用 Transformers Chat CLI

$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-KTO
<quentin_gallouedec>:
What is the best programming language?

<trl-lib/Qwen2-0.5B-KTO>:
The best programming language can vary depending on individual preferences, industry-specific requirements, technical skills, and familiarity with the specific use case or task. Here are some widely-used programming languages that have been noted as popular and widely used:                                                                                  

Here are some other factors to consider when choosing a programming language for a project:

 1 JavaScript: JavaScript is at the heart of the web and can be used for building web applications, APIs, and interactive front-end applications like frameworks like React and Angular. It's similar to C, C++, and F# in syntax structure and is accessible and easy to learn, making it a popular choice for beginners and professionals alike.                                                                   
 2 Java: Known for its object-oriented programming (OOP) and support for Java 8 and .NET, Java is used for developing enterprise-level software applications, high-performance games, as well as mobile apps, game development, and desktop applications.                                                                                                                                                            
 3 C++: Known for its flexibility and scalability, C++ offers comprehensive object-oriented programming and is a popular choice for high-performance computing and other technical fields. It's a powerful platform for building real-world applications and games at scale.                                                                                                                                         
 4 Python: Developed by Guido van Rossum in 1991, Python is a high-level, interpreted, and dynamically typed language known for its simplicity, readability, and versatility.   

预期数据集格式

KTO 需要一个 非成对偏好数据集。或者,您可以提供一个成对偏好数据集(也简称为偏好数据集)。在这种情况下,训练器将通过分离选定的和拒绝的响应,自动将其转换为非成对格式,为选定的完成分配 label = True,为拒绝的完成分配 label = False

KTOTrainer 支持对话式标准数据集格式。当提供对话式数据集时,训练器将自动将聊天模板应用于数据集。

理论上,数据集应至少包含一个选定的和一个拒绝的完成。但是,一些用户已成功使用 选定或仅拒绝的数据运行 KTO。如果仅使用拒绝的数据,建议采用保守的学习率。

示例脚本

我们提供了一个示例脚本,用于使用 KTO 方法训练模型。该脚本可在 trl/scripts/kto.py 中找到

要在 UltraFeedback 数据集上使用 Qwen2 0.5B 模型 测试 KTO 脚本,请运行以下命令

accelerate launch trl/scripts/kto.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --dataset_name trl-lib/kto-mix-14k \
    --num_train_epochs 1 \
    --logging_steps 25 \
    --output_dir Qwen2-0.5B-KTO

使用技巧

对于专家混合模型:启用辅助损失

如果专家之间的负载大致均衡分布,则 MOE 最有效。
为了确保我们在偏好调整期间以类似的方式训练 MOE,最好将负载均衡器的辅助损失添加到最终损失中。

通过在模型配置中设置 output_router_logits=True(例如 MixtralConfig)来启用此选项。
要调整辅助损失对总损失的贡献程度,请在模型配置中使用超参数 router_aux_loss_coef=...(默认值:0.001)。

批量大小建议

使用每步至少为 4 的批量大小,以及 16 到 128 之间的有效批量大小。即使您的有效批量大小很大,如果您的每步批量大小较差,则 KTO 中的 KL 估计也会较差。

学习率建议

每个 beta 选择都有其可容忍的最大学习率,超过此学习率,学习性能会下降。对于 beta = 0.1 的默认设置,对于大多数模型,学习率通常不应超过 1e-6。随着 beta 降低,学习率也应相应降低。一般来说,我们强烈建议将学习率保持在 5e-75e-6 之间。即使使用小型数据集,我们也不建议使用此范围以外的学习率。相反,选择更多 epoch 以获得更好的结果。

不平衡数据

KTOConfigdesirable_weightundesirable_weight 指的是理想/正面示例和不理想/负面示例的损失权重。默认情况下,它们都为 1。但是,如果您其中一种类型的示例更多,则应提高较不常见类型的权重,以使(desirable_weight×\times正例数量)与(undesirable_weight×\times负例数量)的比率在 1:1 到 4:3 的范围内。

记录的指标

在训练和评估期间,我们记录以下奖励指标

  • rewards/chosen_sum:策略模型对于选定响应的对数概率之和,按 beta 缩放
  • rewards/rejected_sum:策略模型对于拒绝响应的对数概率之和,按 beta 缩放
  • logps/chosen_sum:选定完成的对数概率之和
  • logps/rejected_sum:拒绝完成的对数概率之和
  • logits/chosen_sum:选定完成的 logits 之和
  • logits/rejected_sum:拒绝完成的 logits 之和
  • count/chosen:批次中选定样本的计数
  • count/rejected:批次中拒绝样本的计数

KTOTrainer

class trl.KTOTrainer

< >

( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, str] = None ref_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, str, NoneType] = None args: KTOConfig = None train_dataset: typing.Optional[datasets.arrow_dataset.Dataset] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None model_init: typing.Optional[typing.Callable[[], transformers.modeling_utils.PreTrainedModel]] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None peft_config: typing.Optional[dict] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalLoopOutput], dict]] = None model_adapter_name: typing.Optional[str] = None ref_adapter_name: typing.Optional[str] = None )

参数

  • model (transformers.PreTrainedModel) — 要训练的模型,最好是 AutoModelForSequenceClassification
  • ref_model (PreTrainedModelWrapper) — Hugging Face transformer 模型,带有因果语言建模头。用于隐式奖励计算和损失。如果未提供参考模型,则训练器将创建一个与要优化的模型具有相同架构的参考模型。
  • args (KTOConfig) — 用于训练的参数。
  • train_dataset (datasets.Dataset) — 用于训练的数据集。
  • eval_dataset (datasets.Dataset) — 用于评估的数据集。
  • processing_class (PreTrainedTokenizerBaseBaseImageProcessorFeatureExtractionMixinProcessorMixin, 可选) — 用于处理数据的处理类。 如果提供,将用于自动处理模型的输入,并将与模型一起保存,以便更容易地重新运行中断的训练或重用微调模型。
  • data_collator (transformers.DataCollator, 可选, 默认为 None) — 用于训练的数据收集器。 如果未指定,将使用默认数据收集器 (DPODataCollatorWithPadding),它将根据批次中序列的最大长度填充序列,给定成对序列的数据集。
  • model_init (Callable[[], transformers.PreTrainedModel]) — 用于训练的模型初始化器。 如果未指定,将使用默认模型初始化器。
  • callbacks (list[transformers.TrainerCallback]) — 用于训练的回调。
  • optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]) — 用于训练的优化器和调度器。
  • preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) — 用于在计算指标之前预处理 logits 的函数。
  • peft_config (dict, 默认为 None) — 用于训练的 PEFT 配置。 如果您传递 PEFT 配置,模型将被包装在 PEFT 模型中。
  • compute_metrics (Callable[[EvalPrediction], dict], 可选) — 用于计算指标的函数。 必须接受 EvalPrediction 并返回从字符串到指标值的字典。
  • model_adapter_name (str, 默认为 None) — 当 LoRA 与多个适配器一起使用时,训练目标 PEFT 适配器的名称。
  • ref_adapter_name (str, 默认为 None) — 当 LoRA 与多个适配器一起使用时,参考 PEFT 适配器的名称。

初始化 KTOTrainer。

compute_reference_log_probs

< >

( padded_batch: dict )

计算 KTO 特定数据集的单个填充批次的参考模型的对数概率。

create_model_card

< >

( model_name: typing.Optional[str] = None dataset_name: typing.Optional[str] = None tags: typing.Union[str, list[str], NoneType] = None )

参数

  • model_name (strNone, 可选, 默认为 None) — 模型的名称。
  • dataset_name (strNone, 可选, 默认为 None) — 用于训练的数据集的名称。
  • tags (str, list[str]None, 可选, 默认为 None) — 要与模型卡关联的标签。

使用 Trainer 可用的信息创建模型卡的草稿。

evaluation_loop

< >

( dataloader: DataLoader description: str prediction_loss_only: typing.Optional[bool] = None ignore_keys: typing.Optional[list[str]] = None metric_key_prefix: str = 'eval' )

覆盖内置的评估循环以存储每个批次的指标。 Trainer.evaluate()Trainer.predict() 共享的预测/评估循环。

有标签或无标签均可使用。

generate_from_model_and_ref

< >

( model batch: dict )

从模型和参考模型中为给定的输入批次生成样本。

get_batch_logps

< >

( logits: FloatTensor labels: LongTensor average_log_prob: bool = False label_pad_token_id: int = -100 is_encoder_decoder: bool = False )

参数

  • logits — 模型的 Logits (未归一化)。 形状:(batch_size, sequence_length, vocab_size)
  • labels — 用于计算对数概率的标签。 值 label_pad_token_id 的标签标记将被忽略。 形状:(batch_size, sequence_length)
  • average_log_prob — 如果为 True,则返回每个(非掩码)标记的平均对数概率。 否则,返回(非掩码)标记的对数概率之和。

计算给定 logits 下给定标签的对数概率。

get_batch_loss_metrics

< >

( model batch: dict )

计算给定输入批次的 KTO 损失和其他指标,用于训练或测试。

get_eval_dataloader

< >

( eval_dataset: typing.Optional[datasets.arrow_dataset.Dataset] = None )

参数

  • eval_dataset (torch.utils.data.Dataset, 可选) — 如果提供,将覆盖 self.eval_dataset。 如果它是 Dataset,则会自动删除 model.forward() 方法不接受的列。 它必须实现 __len__

返回评估 ~torch.utils.data.DataLoader

transformers.src.transformers.trainer.get_eval_dataloader 的子类,用于预计算 ref_log_probs

get_train_dataloader

< >

( )

返回训练 ~torch.utils.data.DataLoader

transformers.src.transformers.trainer.get_train_dataloader 的子类,用于预计算 ref_log_probs

kto_loss

< >

( policy_chosen_logps: FloatTensor policy_rejected_logps: FloatTensor policy_KL_logps: FloatTensor reference_chosen_logps: FloatTensor reference_rejected_logps: FloatTensor reference_KL_logps: FloatTensor ) 一个包含四个张量的元组

参数

  • policy_chosen_logps — 策略模型对于被选择响应的对数概率。形状: (batch_size 中的 num(chosen),)
  • policy_rejected_logps — 策略模型对于被拒绝响应的对数概率。形状: (batch_size 中的 num(rejected),)
  • policy_KL_logps — 策略模型对于 KL 响应的对数概率。形状: (batch_size,)
  • reference_chosen_logps — 参考模型对于被选择响应的对数概率。形状: (batch_size 中的 num(chosen),)
  • reference_rejected_logps — 参考模型对于被拒绝响应的对数概率。形状: (batch_size 中的 num(rejected),)
  • reference_KL_logps — 参考模型对于 KL 响应的对数概率。形状: (batch_size,)

返回

一个包含四个张量的元组

(losses, chosen_rewards, rejected_rewards, KL)。losses 张量包含批次中每个样本的 KTO 损失。chosen_rewards 和 rejected_rewards 张量分别包含被选择和被拒绝响应的奖励。KL 张量包含策略模型和参考模型之间分离的 KL 散度估计值。

计算一批策略模型和参考模型对数概率的 KTO 损失。

log

< >

( logs: dict start_time: typing.Optional[float] = None )

参数

  • logs (dict[str, float]) — 要记录的值的字典。
  • start_time (floatNone, 可选, 默认为 None) — 训练的开始时间。

在各种监控训练的对象(包括存储的指标)上记录 logs

null_ref_context

< >

( )

用于处理空参考模型的上下文管理器(即,peft 适配器操作)。

KTOConfig

class trl.KTOConfig

< >

( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 1e-06 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict, str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: typing.Optional[str] = 'passive' log_level_replica: typing.Optional[str] = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict, str, NoneType] = None tp_size: typing.Optional[int] = 0 fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict, str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: typing.Optional[int] = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = False max_length: typing.Optional[int] = 1024 max_prompt_length: typing.Optional[int] = 512 max_completion_length: typing.Optional[int] = None beta: float = 0.1 loss_type: str = 'kto' desirable_weight: float = 1.0 undesirable_weight: float = 1.0 label_pad_token_id: int = -100 padding_value: typing.Optional[int] = None truncation_mode: str = 'keep_end' generate_during_eval: bool = False is_encoder_decoder: typing.Optional[bool] = None disable_dropout: bool = True precompute_ref_log_probs: bool = False model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None ref_model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None dataset_num_proc: typing.Optional[int] = None )

参数

  • learning_rate (float, optional, defaults to 1e-6) — AdamW 优化器的初始学习率。默认值替换了 TrainingArguments 中的默认值。
  • max_length (intNone, optional, defaults to 1024) — 批次中序列(prompt + completion)的最大长度。如果您想使用默认数据收集器,则此参数是必需的。
  • max_prompt_length (intNone, optional, defaults to 512) — prompt 的最大长度。如果您想使用默认数据收集器,则此参数是必需的。
  • max_completion_length (intNone, optional, defaults to None) — completion 的最大长度。如果您想使用默认数据收集器,并且您的模型是 encoder-decoder 模型,则此参数是必需的。
  • beta (float, optional, defaults to 0.1) — 控制与参考模型偏差的参数。较高的 β 值意味着与参考模型的偏差较小。
  • loss_type (str, optional, defaults to "kto") — 要使用的损失类型。可能的值为:

    • "kto": 来自 KTO 论文的 KTO 损失。
    • "apo_zero_unpaired": 来自 APO 论文的 APO-zero 损失的非成对变体。
  • desirable_weight (float, optional, defaults to 1.0) — 期望的损失由此因子加权,以抵消期望的和不期望的对的不等数量。
  • undesirable_weight (float, optional, defaults to 1.0) — 不期望的损失由此因子加权,以抵消期望的和不期望的对的不等数量。
  • label_pad_token_id (int, optional, defaults to -100) — 标签填充 token id。如果您想使用默认数据收集器,则此参数是必需的。
  • padding_value (intNone, 可选, 默认为 None) — 要使用的填充值。如果为 None,则使用 tokenizer 的填充值。
  • truncation_mode (str, 可选, 默认为 "keep_end") — 当 prompt 过长时使用的截断模式。可选值为 "keep_end""keep_start"。如果您想使用默认的数据收集器,则此参数是必需的。
  • generate_during_eval (bool, 可选, 默认为 False) — 如果为 True,则在评估期间从模型和参考模型生成完成结果并记录到 W&B 或 Comet。
  • is_encoder_decoder (boolNone, 可选, 默认为 None) — 当使用 model_init 参数 (可调用对象) 来实例化模型而不是 model 参数时,您需要指定可调用对象返回的模型是否为 encoder-decoder 模型。
  • precompute_ref_log_probs (bool, 可选, 默认为 False) — 是否为训练和评估数据集预计算参考模型的对数概率。当在没有参考模型的情况下进行训练以减少所需的 GPU 内存总量时,这很有用。
  • model_init_kwargs (dict[str, Any]None, 可选, 默认为 None) — 从字符串实例化模型时,传递给 AutoModelForCausalLM.from_pretrained 的关键字参数。
  • ref_model_init_kwargs (dict[str, Any]None, 可选, 默认为 None) — 从字符串实例化参考模型时,传递给 AutoModelForCausalLM.from_pretrained 的关键字参数。
  • dataset_num_proc — (intNone, 可选, 默认为 None):用于处理数据集的进程数。
  • disable_dropout (bool, 可选, 默认为 True) — 是否禁用模型和参考模型中的 dropout。

KTOTrainer 的配置类。

使用 HfArgumentParser,我们可以将此类转换为可以在命令行中指定的 argparse 参数。

< > 在 GitHub 上更新