TRL 文档
在线 DPO 训练器
并获得增强的文档体验
开始使用
在线 DPO 训练器
概览
在线 DPO 由 Shangmin Guo、Biao Zhang、Tianlin Liu、Tianqi Liu、Misha Khalman、Felipe Llinares、Alexandre Rame、Thomas Mesnard、Yao Zhao、Bilal Piot、Johan Ferret 和 Mathieu Blondel 在 来自在线 AI 反馈的直接语言模型对齐 中提出。
该论文的摘要如下:
来自偏好的直接对齐 (DAP) 方法,例如 DPO,最近作为从人类反馈中进行强化学习 (RLHF) 的有效替代方案而出现,它们不需要单独的奖励模型。但是,DAP 方法中使用的偏好数据集通常在训练之前收集,并且永远不会更新,因此反馈纯粹是离线的。此外,这些数据集中的响应通常从与正在对齐的语言模型不同的语言模型中采样,并且由于模型在训练过程中不断发展,因此对齐阶段不可避免地是离策略的。在这项研究中,我们认为在线反馈是关键,并且可以改进 DAP 方法。我们的方法,在线 AI 反馈 (OAIF),使用 LLM 作为注释器:在每次训练迭代中,我们从当前模型中采样两个响应,并提示 LLM 注释器选择哪个是首选的,从而提供在线反馈。尽管它很简单,但我们通过在多项任务中的人工评估证明,OAIF 的性能优于离线 DAP 和 RLHF 方法。我们进一步表明,通过向 LLM 注释器发出指令提示,可以轻松控制 OAIF 中利用的反馈。
这种后训练方法由 Michael Noukhovitch、Shengyi Costa Huang、Quentin Gallouédec 和 Edward Beeching 贡献。
快速开始
此示例演示了如何使用在线 DPO 方法训练模型。我们使用 Qwen 0.5B 模型 作为基础模型,并使用 PairRMJudge 作为评判器。我们使用来自 UltraFeedback 数据集 的提示。您可以在此处查看数据集中的提示
以下是训练模型的脚本
# train_online_dpo.py
from datasets import load_dataset
from trl import OnlineDPOConfig, OnlineDPOTrainer, PairRMJudge
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
training_args = OnlineDPOConfig(output_dir="Qwen2-0.5B-OnlineDPO", logging_steps=10)
trainer = OnlineDPOTrainer(
model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)
trainer.train()
使用以下命令执行脚本
accelerate launch train_online_dpo.py
在 8 个 GPU 上分布式训练大约需要 1 小时。您可以通过检查奖励图来验证训练进度。被拒绝和选定完成的奖励均呈上升趋势,这表明模型正在改进并且随着时间的推移生成更好的响应。
要了解 训练后的模型 的性能,您可以使用 Transformers Chat CLI。
$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-OnlineDPO
<quentin_gallouedec>:
What is the best programming language?
<trl-lib/Qwen2-0.5B-OnlineDPO>:
The best programming language depends on your specific needs and priorities. Some people prefer imperative programming languages (like Haskell or Lisp), while others prefer functional programming languages (like Scala or Python). It's important to consider your work style, programming environment, and project requirements when choosing a programming language.
预期数据集类型
在线 DPO 仅需要 仅提示数据集(与离线 DPO 不同,离线 DPO 需要 偏好数据集)。OnlineDPOTrainer 支持 对话式 和 标准 数据集格式。当提供对话式数据集时,训练器将自动将聊天模板应用于数据集。
使用技巧
使用奖励模型
您可以选择使用奖励模型而不是评判器——请参阅 Reward Bench,其中包含您可以使用的公共模型的排行榜。以下代码示例展示了如何使用 trl-lib/Qwen2-0.5B-Reward 模型替换评判器
- from trl import PairRMJudge
+ from transformers import AutoModelForSequenceClassification
- judge = PairRMJudge()
+ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
+ reward_tokenizer = AutoTokenizer.from_pretrained("trl-lib/Qwen2-0.5B-Reward")
trainer = OnlineDPOTrainer(
...
- judge=judge,
+ reward_model=reward_model,
+ reward_processing_class=reward_tokenizer,
...
)
鼓励 EOS 令牌生成
当使用奖励模型时,我们可能希望模型在给定的长度内生成完成结果。在训练期间,模型将生成完成结果,直到 OnlineDPOConfig 的 max_new_tokens
参数中指定的最大长度。如果您想惩罚模型在达到最大长度之前未生成 EOS 令牌,则可以使用 OnlineDPOConfig 的 missing_eos_penalty
参数
training_args = OnlineDPOConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
记录完成
为了更好地了解模型在训练期间的行为,您可以使用 LogCompletionsCallback 定期记录样本完成情况。
trainer = OnlineDPOTrainer(..., eval_dataset=eval_dataset)
completions_callback = LogCompletionsCallback(trainer, num_prompts=8)
trainer.add_callback(completions_callback)
此回调直接将模型生成的完成情况记录到 Weights & Biases。
示例脚本
我们提供了一个示例脚本,用于使用在线 DPO 方法训练模型。该脚本位于 examples/scripts/dpo_online.py
中
要使用 Qwen2.5 0.5B 模型 在 UltraFeedback 数据集 上测试在线 DPO 脚本,请运行以下命令
python examples/scripts/dpo_online.py \ --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ --judge pair_rm \ --dataset_name trl-lib/ultrafeedback-prompt \ --learning_rate 5.0e-7 \ --logging_steps 25 \ --output_dir Qwen2.5-0.5B-Online-DPO-PairRM \ --warmup_ratio 0.1 \ --push_to_hub
记录的指标
记录的指标如下。这是一个 在 Weights and Biases 上跟踪运行的示例
objective/kl
:当前模型和参考模型之间的平均 Kullback-Leibler (KL) 散度。objective/entropy
:模型的平均熵,表示模型选择的操作的随机性。objective/non_score_reward
:来自非评分相关来源的平均奖励,基本上是beta * kl.sum(1)
,其中beta
是 KL 惩罚系数,kl
是每个令牌的 KL 散度。objective/rlhf_reward
:平均 RLHF 奖励,即scores - non_score_reward
。rlhf_reward
是在线 DPO 训练的最终目标。如果训练按预期进行,则此指标应持续上升。objective/scores
:奖励模型返回的平均分数。objective/scores_margin
:选定完成和拒绝完成之间的平均分数差(根据外部奖励模型)。rewards/chosen
:选定完成的平均奖励(根据在线 DPO 的隐式奖励模型)。rewards/rejected
:拒绝完成的平均奖励(根据在线 DPO 的隐式奖励模型)。rewards/accuracies
:在线 DPO 隐式奖励模型的准确率。rewards/margins
:选定完成和拒绝完成之间的平均奖励差(根据在线 DPO 的隐式奖励模型)。logps/chosen
:选定完成的平均对数概率。logps/rejected
:拒绝完成的平均对数概率。val/contain_eos_token
:包含 EOS 令牌的完成结果的比例。beta
:控制表示与参考模型偏差的损失项权重的参数。通常是固定的,但可以通过将列表传递给 OnlineDPOConfig 来使其动态化。
基准实验
为了验证在线 DPO 实现的有效性,我们使用 Pythia 1B、2.8B 和 6.9B 模型在 8 个 H100 的单节点上运行了实验。以下是我们用于运行实验的命令。我们直接从 RLHF 与 PPO 的 N+ 实现细节:TL;DR 摘要的案例研究 中获取 SFT/RM 模型。
# 1B Online DPO experiment
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml \
examples/scripts/dpo_online.py \
--model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
--reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
--dataset_name trl-lib/tldr \
--learning_rate 5.0e-7 \
--output_dir pythia-1b-deduped-tldr-online-dpo \
--beta 0.1 \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 2 \
--num_train_epochs 3 \
--max_new_tokens 53 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--logging_steps 20 \
--save_steps 0.1 \
--push_to_hub
# 2.8B Online DPO experiment
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/dpo_online.py \
--model_name_or_path trl-lib/pythia-2.8b-deduped-tldr-sft \
--reward_model_path trl-lib/pythia-2.8b-deduped-tldr-rm \
--dataset_name trl-lib/tldr \
--learning_rate 5.0e-7 \
--output_dir pythia-2.8b-deduped-tldr-online-dpo \
--beta 0.1 \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 2 \
--num_train_epochs 3 \
--max_new_tokens 53 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--bf16 \
--logging_steps 20 \
--save_steps 0.1 \
--push_to_hub
# 6.9B Online DPO experiment
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/dpo_online.py \
--model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-sft \
--reward_model_path trl-lib/pythia-6.9b-deduped-tldr-rm \
--dataset_name trl-lib/tldr \
--learning_rate 5.0e-7 \
--output_dir pythia-6.9b-deduped-tldr-online-dpo \
--beta 0.1 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--num_train_epochs 3 \
--max_new_tokens 53 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--bf16 \
--gradient_checkpointing \
--logging_steps 20 \
--save_steps 0.1 \
--push_to_hub
检查点和实验跟踪可在以下位置找到
为了评估,我们使用 vLLM 加载检查点,并使用 GPT-4o mini 作为评判模型来评估生成的 TL;DR 与参考 TL;DR。有关如何使用评判器的更多信息,请参阅 评判器。
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft --judge_model gpt-4o-mini --num_examples 1000 Model win rate: 33.00% python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-sft --judge_model gpt-4o-mini --num_examples 1000 Model win rate: 41.50% python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-1b-deduped-tldr-online-dpo --judge_model gpt-4o-mini --num_examples 1000 Model win rate: 62.60% python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-online-dpo --judge_model gpt-4o-mini --num_examples 1000 Model win rate: 74.20%
然后我们可以绘制 RLHF 缩放图表。
import matplotlib.pyplot as plt
results = {
"SFT": {1.0e9: 0.21, 2.8e9: 0.27, 6.9e9: 0.316},
"online-dpo": {1.0e9: 0.542, 2.8e9: 0.746, 6.9e9: 0.796},
"offline-dpo": {1.0e9: 0.422, 2.8e9: 0.517, 6.9e9: 0.701},
}
plt.plot(results["SFT"].keys(), results["SFT"].values(), label="SFT", marker="o")
plt.plot(results["online-dpo"].keys(), results["online-dpo"].values(), label="Online-dpo with RM judge", marker="o")
plt.plot(results["offline-dpo"].keys(), results["offline-dpo"].values(), label="Offline-dpo", marker="o")
plt.axhline(y=0.5, color="black", linestyle="-.", label="Human reference summary")
plt.xscale("log")
plt.xlabel("Model size")
plt.ylabel("Win rate against reference summaries\n(according to GPT-4-0613)")
plt.title("DPO scaling by model size")
plt.legend()
plt.xlim(5e8, 1.2e10)
plt.xticks([1e9, 3e9, 1e10], ["1B", "3B", "10B"])
plt.grid(True, which="both", ls="--", c="0.7")
plt.tight_layout()
plt.show()
随着我们扩大模型尺寸,在线 DPO 检查点获得了越来越高的胜率。这是一个在线 DPO 实现按预期工作的良好迹象。
OnlineDPOTrainer
class trl.OnlineDPOTrainer
< source >( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] ref_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, NoneType] = None reward_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, NoneType] = None judge: typing.Optional[trl.trainer.judges.BasePairwiseJudge] = None args: typing.Optional[trl.trainer.online_dpo_config.OnlineDPOConfig] = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None train_dataset: typing.Union[datasets.arrow_dataset.Dataset, torch.utils.data.dataset.IterableDataset, ForwardRef('datasets.Dataset'), NoneType] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], ForwardRef('datasets.Dataset'), NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None reward_processing_class: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None peft_config: typing.Optional[dict] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], dict]] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None )
参数
- model (
transformers.PreTrainedModel
或torch.nn.Module
) — 要训练的模型,最好是AutoModelForCausalLM
。 - ref_model (
transformers.PreTrainedModel
或torch.nn.Module
或None
) — 用于训练的参考模型。如果未指定 None,则将从模型创建参考模型。 - reward_model (
transformers.PreTrainedModel
或torch.nn.Module
或None
) — 用于对完成结果进行评分的奖励模型,最好是AutoModelForSequenceClassification
。 - judge (
BasePairwiseJudge
) — 用于模型补全的成对比较的判断器。 - args (
OnlineDPOConfig
) — 用于训练的在线 DPO 配置参数。 - data_collator (
transformers.DataCollator
) — 用于训练的数据整理器。如果未指定,将使用默认数据整理器 (DPODataCollatorWithPadding
),它将在给定成对序列数据集的情况下,将序列填充到批次中最长序列的长度。 - train_dataset (
datasets.Dataset
) — 用于训练的数据集。 - eval_dataset (
datasets.Dataset
) — 用于评估的数据集。 - processing_class (
PreTrainedTokenizerBase
orBaseImageProcessor
orFeatureExtractionMixin
orProcessorMixin
, 可选) — 用于处理数据的处理类。如果提供,将用于自动处理模型的输入,并将其与模型一起保存,以便更容易地重新运行中断的训练或重用微调后的模型。 - peft_config (
dict
) — 用于训练的 peft 配置。 - compute_metrics (
Callable[[EvalPrediction], dict]
, 可选) — 用于计算指标的函数。必须接受一个EvalPrediction
并返回一个字典,其中字符串映射到指标值。 - callbacks (
list[transformers.TrainerCallback]
) — 用于训练的回调列表。 - optimizers (
tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]
) — 用于训练的优化器和调度器。 - preprocess_logits_for_metrics (
Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
) — 用于在计算指标之前预处理 logits 的函数。
初始化 OnlineDPOTrainer。
create_model_card
< source >( model_name: typing.Optional[str] = None dataset_name: typing.Optional[str] = None tags: typing.Union[str, list[str], NoneType] = None )
使用 Trainer
可用的信息创建模型卡的草稿。
从 DPO 特定数据集中 Tokenize 单行。
OnlineDPOConfig
class trl.OnlineDPOConfig
< source >( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 5e-07 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict, str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: typing.Optional[str] = 'passive' log_level_replica: typing.Optional[str] = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict, str, NoneType] = None tp_size: typing.Optional[int] = 0 fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict, str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: typing.Optional[int] = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = False reward_model_path: typing.Optional[str] = None judge: typing.Optional[str] = None max_new_tokens: int = 64 max_length: int = 512 temperature: float = 0.9 missing_eos_penalty: typing.Optional[float] = None beta: list = <factory> loss_type: str = 'sigmoid' dataset_num_proc: typing.Optional[int] = None disable_dropout: bool = True use_vllm: bool = False ds3_gather_for_generation: bool = True )
参数
- learning_rate (
float
, optional, defaults to5e-7
) —AdamW
优化器的初始学习率。默认值替换了TrainingArguments
的默认值。 - reward_model_path (
str
orNone
, optional, defaults toNone
) — 奖励模型的路径。必须设置judge
或reward_model_path
,但不能同时设置两者。 - judge (
str
orNone
, optional, defaults toNone
) — 要使用的评判器的名称。必须设置judge
或reward_model_path
,但不能同时设置两者。 - max_new_tokens (
int
, optional, defaults to64
) — 每次完成生成最多的 token 数量。 - max_length (
int
, optional, defaults to256
) — 用于计算对数概率的序列(提示 + 完成)的最大总长度。如果序列超过此限制,将截断最左边的 token,以尽可能保留完成部分。 - temperature (
float
, optional, defaults to0.9
) — 采样的温度。温度越高,完成结果越随机。 - missing_eos_penalty (
float
或None
, 可选, 默认为None
) — 当模型未能生成 EOS token 时,应用于分数的惩罚。这有助于鼓励生成短于最大长度 (max_new_tokens
) 的补全内容。惩罚值必须为正数。 - beta (
float
或list[float]
, 可选, 默认为0.1
) — 控制与参考模型偏差的参数。β 值越高,表示与参考模型的偏差越小。对于 IPO 损失 (loss_type="ipo"
),β 是 论文 中用 τ 表示的正则化参数。如果提供浮点数列表,则为每个新 epoch 选择一个 β 值,最后一个 β 值用于其余的 epoch。 - loss_type (
str
, 可选, 默认为"sigmoid"
) — 要使用的损失类型。可能的值包括: - dataset_num_proc (
int
或None
, 可选, 默认为None
) — 用于处理数据集的进程数。 - disable_dropout (
bool
, 可选, 默认为True
) — 是否禁用模型和参考模型中的 dropout。 - use_vllm (
bool
, 可选, 默认为False
) — 是否使用 vLLM 生成补全内容。需要安装 vLLM (pip install vllm
)。 - ds3_gather_for_generation (
bool
, 可选, 默认为True
) — 此设置适用于 DeepSpeed ZeRO-3。如果启用,策略模型权重将被收集用于生成,从而提高生成速度。但是,禁用此选项允许训练超出单个 GPU 的 VRAM 容量的模型,尽管会以较慢的生成速度为代价。
用于 OnlineDPOTrainer 的配置类。
使用 HfArgumentParser
,我们可以将此类转换为 argparse 参数,这些参数可以在命令行中指定。