在线 DPO 训练器
概述
在线 DPO 在 从在线人工智能反馈中直接语言模型对齐 中由 Shangmin Guo、Biao Zhang、Tianlin Liu、Tianqi Liu、Misha Khalman、Felipe Llinares、Alexandre Rame、Thomas Mesnard、Yao Zhao、Bilal Piot、Johan Ferret 和 Mathieu Blondel 提出。
论文摘要如下
从偏好中直接对齐 (DAP) 方法,例如 DPO,最近成为对来自人类反馈的强化学习 (RLHF) 的有效替代方案,它们不需要单独的奖励模型。然而,DAP 方法中使用的偏好数据集通常是在训练之前收集的,并且从未更新,因此反馈纯粹是离线的。此外,这些数据集中的响应通常从与被对齐的模型不同的语言模型中采样,并且由于模型在训练过程中不断发展,对齐阶段不可避免地是脱策略的。在本研究中,我们认为在线反馈是关键,并且可以改进 DAP 方法。我们的方法,在线人工智能反馈 (OAIF),使用 LLM 作为注释器:在每次训练迭代中,我们从当前模型中采样两个响应,并提示 LLM 注释器选择哪个响应更受欢迎,从而提供在线反馈。尽管非常简单,但我们通过在多个任务中进行的人工评估证明 OAIF 优于离线 DAP 和 RLHF 方法。我们进一步表明,OAIF 中利用的反馈很容易通过对 LLM 注释器的指令提示进行控制。
当前实现使用奖励模型来对完成情况进行评分 - 请参阅 奖励基准 以了解您可以使用的公开模型排行榜。
这种训练后方法由 Michael Noukhovitch、Shengyi Costa Huang、Quentin Gallouédec 和 Edward Beeching 贡献。
快速入门
本示例演示了如何使用在线 DPO 方法训练模型。我们使用 Qwen 0.5B 模型 作为基础模型,使用 Qwen 0.5B 奖励模型 作为奖励模型。我们使用 UltraFeedback 数据集 中的提示。您可以在此处查看数据集中的提示
以下是训练模型的脚本
# train_online_dpo.py
from datasets import load_dataset
from trl import OnlineDPOConfig, OnlineDPOTrainer
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
args = OnlineDPOConfig(output_dir="online-dpo-qwen2", logging_steps=10)
trainer = OnlineDPOTrainer(
model=model,
reward_model=reward_model,
args=args,
tokenizer=tokenizer,
train_dataset=train_dataset,
)
trainer.train()
使用以下命令执行脚本
accelerate launch train_online_dpo.py
在 8 个 GPU 上分布式运行,训练大约需要 1 个小时。您可以通过检查奖励图来验证训练进度。拒绝和选定完成的奖励都呈上升趋势表明模型正在改进,并随着时间的推移生成更好的响应。
要查看训练后的模型的性能,请使用以下代码生成完成
>>> from transformers import pipeline
>>> generator = pipeline("text-generation", model="online-dpo-qwen2/checkpoint-1773", device="cuda")
>>> question = "Why is the problem always DNS?"
>>> output = generator([{"role": "user", "content": question}], max_new_tokens=200, return_full_text=False)[0]
>>> print(output["generated_text"])
The reason why the problem of DNS (Domain Name System) can always be encountered is that it is designed to provide reliable and accurate information about the availability, ownership, or expiration of domain names. However, there may be some circumstances where the system fails to resolve an IP address correctly, leading to the problem of DNS.
For example, if the server hosting the domain name does not have the correct IP address associated with it, or if the IP address is incorrectly formatted, then the DNS system will fail to resolve the domain name correctly. Additionally, if the server hosting the domain name has been compromised, then the DNS system may also fail to resolve the domain name correctly.
It's worth noting that the exact cause of DNS failure can vary depending on the specific situation, so it's important to carefully check all relevant factors before attempting to resolve the issue. If you suspect that your DNS problem may be caused by a bug in the system, you should report it to the DNS provider directly for further investigation.
预期数据集格式
在线 DPO 只需要一个 仅提示数据集(与离线 DPO 不同,离线 DPO 需要一个 偏好数据集)。OnlineDPOTrainer 支持 对话 和 标准 数据集格式。当提供对话数据集时,训练器会自动将聊天模板应用于数据集。
使用技巧
⚠️ 使用相同的聊天模板
确保 SFT 模型和奖励模型使用相同的聊天模板。否则,您可能会发现模型完成在训练期间被错误地评分。
鼓励生成 EOS 符号
我们希望模型在给定长度内生成完成。在学习过程中,模型将生成完成,直到在 OnlineDPOConfig 的 max_new_tokens
参数中指定的最大完成长度。如果您想在达到最大完成长度之前没有生成 EOS 符号时进行处罚,您可以使用 OnlineDPOConfig 的 missing_eos_penalty
参数。
args = OnlineDPOConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
记录完成
为了更好地了解模型在训练过程中的行为,您可以使用 LogCompletionsCallback 定期记录样本完成。
trainer = OnlineDPOTrainer(..., eval_dataset=eval_dataset)
completions_callback = LogCompletionsCallback(trainer, num_prompts=8)
trainer.add_callback(completions_callback)
此回调将模型生成的完成直接记录到 Weights & Biases。
示例脚本
我们提供了一个示例脚本,使用在线 DPO 方法训练模型。该脚本在 examples/scripts/dpo_online.py
中提供。
要使用 Pythia 1B 模型 在 TL;DR 摘要任务上测试在线 DPO 脚本,请运行以下命令
python examples/scripts/dpo_online.py \ --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \ --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \ --dataset_name trl-lib/tldr \ --learning_rate 5.0e-7 \ --output_dir pythia-1b-tldr-online-dpo \ --per_device_train_batch_size 4 \ --gradient_accumulation_steps 32 \ --num_train_epochs 3 \ --max_new_tokens 53 \ --warmup_ratio 0.1 \ --missing_eos_penalty 1.0 \ --push_to_hub
记录的指标
记录的指标如下。这是一个示例 在 Weights and Biases 上跟踪的运行
objective/kl
:当前模型与参考模型之间的平均 Kullback-Leibler (KL) 散度。objective/entropy
:模型的平均熵,表示模型选择的动作的随机性。objective/non_score_reward
:来自非分数相关来源的平均奖励,基本上是beta * kl.sum(1)
,其中beta
是 KL 惩罚系数,kl
是每个符号的 KL 散度。objective/rlhf_reward
:平均 RLHF 奖励,即scores - non_score_reward
。rlhf_reward
是在线 DPO 训练的最终目标。如果训练按预期进行,此指标应该会持续上升。objective/scores
:奖励模式返回的平均分数。objective/scores_margin
:选定完成与拒绝完成之间的平均分数差(根据外部奖励模型)。rewards/chosen
:选定完成的平均奖励(根据在线 DPO 的隐式奖励模型)。rewards/rejected
:拒绝完成的平均奖励(根据在线 DPO 的隐式奖励模型)。rewards/accuracies
:在线 DPO 的隐式奖励模型的准确率。rewards/margins
:选定完成与拒绝完成之间的平均奖励差(根据在线 DPO 的隐式奖励模型)。logps/chosen
:选定完成的平均对数概率。logps/rejected
:拒绝完成的平均对数概率。val/contain_eos_token
:包含 EOS 符号的完成的比例。beta
:控制表示与参考模型偏差的损失项的权重的参数。通常是固定的,但可以通过将列表传递给 OnlineDPOConfig 来使其动态化。
基准实验
为了验证在线 DPO 实现有效,我们在 8 x H100s 的单个节点上使用 Pythia 1B、2.8B 和 6.9B 模型进行了实验。以下是我们用来运行实验的命令。我们直接从 RLHF 与 PPO 的 N+ 实现细节:以 TL;DR 摘要为例 中获取 SFT / RM 模型。
# 1B Online DPO experiment
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml \
examples/scripts/dpo_online.py \
--model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
--reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
--dataset_name trl-lib/tldr \
--learning_rate 5.0e-7 \
--output_dir pythia-1b-deduped-tldr-online-dpo \
--beta 0.1 \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 2 \
--num_train_epochs 3 \
--max_new_tokens 53 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--logging_steps 20 \
--save_steps 0.1 \
--push_to_hub
# 2.8B Online DPO experiment
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/dpo_online.py \
--model_name_or_path trl-lib/pythia-2.8b-deduped-tldr-sft \
--reward_model_path trl-lib/pythia-2.8b-deduped-tldr-rm \
--dataset_name trl-lib/tldr \
--learning_rate 5.0e-7 \
--output_dir pythia-2.8b-deduped-tldr-online-dpo \
--beta 0.1 \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 2 \
--num_train_epochs 3 \
--max_new_tokens 53 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--bf16 \
--logging_steps 20 \
--save_steps 0.1 \
--push_to_hub
# 6.9B Online DPO experiment
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/dpo_online.py \
--model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-sft \
--reward_model_path trl-lib/pythia-6.9b-deduped-tldr-rm \
--dataset_name trl-lib/tldr \
--learning_rate 5.0e-7 \
--output_dir pythia-6.9b-deduped-tldr-online-dpo \
--beta 0.1 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--num_train_epochs 3 \
--max_new_tokens 53 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--bf16 \
--gradient_checkpointing \
--logging_steps 20 \
--save_steps 0.1 \
--push_to_hub
检查点和实验跟踪可在以下位置找到:
为了进行评估,我们使用 vLLM 加载检查点,并使用 GPT-4o mini 作为评判模型,评估生成的 TL;DR 与参考 TL;DR 的一致性。有关如何使用评判模型的更多信息,请参见 评判模型。
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft --judge_model gpt-4o-mini --num_examples 1000 Model win rate: 33.00% python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-sft --judge_model gpt-4o-mini --num_examples 1000 Model win rate: 41.50% python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-1b-deduped-tldr-online-dpo --judge_model gpt-4o-mini --num_examples 1000 Model win rate: 62.60% python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-online-dpo --judge_model gpt-4o-mini --num_examples 1000 Model win rate: 74.20%
然后我们可以绘制 RLHF 扩展图表。
import matplotlib.pyplot as plt
results = {
"SFT": {1.0e9: 0.21, 2.8e9: 0.27, 6.9e9: 0.316},
"online-dpo": {1.0e9: 0.542, 2.8e9: 0.746, 6.9e9: 0.796},
"offline-dpo": {1.0e9: 0.422, 2.8e9: 0.517, 6.9e9: 0.701},
}
plt.plot(results["SFT"].keys(), results["SFT"].values(), label="SFT", marker="o")
plt.plot(results["online-dpo"].keys(), results["online-dpo"].values(), label="Online-dpo with RM judge", marker="o")
plt.plot(results["offline-dpo"].keys(), results["offline-dpo"].values(), label="Offline-dpo", marker="o")
plt.axhline(y=0.5, color="black", linestyle="-.", label="Human reference summary")
plt.xscale("log")
plt.xlabel("Model size")
plt.ylabel("Win rate against reference summaries\n(according to GPT-4-0613)")
plt.title("DPO scaling by model size")
plt.legend()
plt.xlim(5e8, 1.2e10)
plt.xticks([1e9, 3e9, 1e10], ["1B", "3B", "10B"])
plt.grid(True, which="both", ls="--", c="0.7")
plt.tight_layout()
plt.show()
随着模型规模的增加,在线 DPO 检查点的胜率越来越高。这是一个好兆头,表明在线 DPO 实现按预期工作。
OnlineDPOTrainer
class trl.OnlineDPOTrainer
< source >( model: Union ref_model: Union = None reward_model: Union = None judge: Optional = None args: Optional = None data_collator: Optional = None train_dataset: Union = None eval_dataset: Union = None tokenizer: Optional = None peft_config: Optional = None compute_metrics: Optional = None callbacks: Optional = None optimizers: Tuple = (None, None) preprocess_logits_for_metrics: Optional = None )
参数
- model (
transformers.PreTrainedModel
或torch.nn.Module
) — 要训练的模型,最好是AutoModelForCausalLM
。 - ref_model (
transformers.PreTrainedModel
或torch.nn.Module
或None
) — 用于训练的参考模型。如果指定 None,则将从模型创建参考模型。 - reward_model (
transformers.PreTrainedModel
或torch.nn.Module
或None
) — 用于对完成进行评分的奖励模型,最好是AutoModelForSequenceClassification
。 - judge (
BasePairwiseJudge
) — 用于模型完成的成对比较的评判模型。 - args (
OnlineDPOConfig
) — 用于训练的在线 DPO 配置参数。 - data_collator (
transformers.DataCollator
) — 用于训练的数据整理器。如果指定 None,则将使用默认数据整理器 (DPODataCollatorWithPadding
),它将根据配对序列数据集,将序列填充到批次中序列的最大长度。 - train_dataset (
datasets.Dataset
) — 用于训练的数据集。 - eval_dataset (
datasets.Dataset
) — 用于评估的数据集。 - tokenizer (
transformers.PreTrainedTokenizerBase
) — 用于训练的词语切分器。 如果您想使用默认的数据整理器,则此参数是必需的。 - peft_config (
Dict
) — 用于训练的 peft 配置。 - compute_metrics (
Callable[[EvalPrediction], Dict]
, 可选) — 用于计算指标的函数。 必须接收EvalPrediction
并返回一个字典,其中包含指标的值。 - callbacks (
List[transformers.TrainerCallback]
) — 用于训练的回调函数。 - optimizers (
Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]
) — 用于训练的优化器和调度器。 - preprocess_logits_for_metrics (
Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
) — 用于在计算指标之前预处理 logits 的函数。
初始化 OnlineDPOTrainer。
对来自特定于 DPO 数据集的单个行进行词语切分。
OnlineDPOConfig
用于 OnlineDPOTrainer 的配置类。
使用 HfArgumentParser,我们可以将此类转换为 argparse 参数,这些参数可以在命令行中指定。