TRL 文档

回调函数

Hugging Face's logo
加入 Hugging Face 社区

并获得增强文档体验

开始使用

回调函数

SyncRefModelCallback

trl.SyncRefModelCallback

< >

( ref_model: Union accelerator: 可选 )

RichProgressCallback

trl.RichProgressCallback

< >

( )

一个使用 Rich 显示训练或评估进度的 TrainerCallback

WinRateCallback

trl.WinRateCallback

< >

( judge: BasePairwiseJudge trainer: Trainer generation_config: Optional = None num_prompts: int = None )

参数

  • judge (BasePairwiseJudge) — 用于比较补全结果的评判器。
  • trainer (Trainer) — 将回调附加到的训练器。训练器的评估数据集必须包含一个名为 "prompt" 的列,其中包含用于生成补全结果的提示。如果 Trainer 有一个参考模型(通过 ref_model 属性),它将使用此参考模型来生成参考补全结果;否则,它默认为使用初始模型。
  • generation_config (GenerationConfig, 可选) — 用于生成补全结果的生成配置。
  • num_prompts (int, 可选) — 要为其生成补全结果的提示数量。如果未提供,则默认为评估数据集中示例的数量。

一个基于参考计算模型胜率的 TrainerCallback

它使用来自评估数据集的提示生成补全结果,并将训练模型的输出与参考进行比较。参考是模型的初始版本(训练前)或参考模型(如果训练器中可用)。在每个评估步骤中,评判器会使用评判器来确定训练模型的补全结果与参考相比获胜的频率。然后,胜率将记录在训练器的日志中,键名为 "eval_win_rate"

用法

trainer = DPOTrainer(...)
judge = PairRMJudge()
win_rate_callback = WinRateCallback(judge=judge, trainer=trainer)
trainer.add_callback(win_rate_callback)

LogCompletionsCallback

trl.LogCompletionsCallback

< >

( trainer: Trainer generation_config: Optional = None num_prompts: int = None freq: int = None )

参数

  • generation_config (GenerationConfig可选) — 用于生成补全的生成配置。
  • num_prompts (int可选) — 要为其生成补全的提示数量。如果未提供,则默认为评估数据集中的示例数量。
  • freq (int可选) — 记录补全的频率。如果未提供,则默认为训练器的 eval_steps

一个将补全记录到 Weights & Biases 的 TrainerCallback

用法

trainer = DPOTrainer(...)
completions_callback = LogCompletionsCallback(trainer=trainer)
trainer.add_callback(completions_callback)
< > 在 GitHub 上更新