TRL 文档

回调

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

回调函数

SyncRefModelCallback

class trl.SyncRefModelCallback

< >

( ref_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] accelerator: typing.Optional[accelerate.accelerator.Accelerator] )

用于将模型与参考模型同步的回调函数。

RichProgressCallback

class trl.RichProgressCallback

< >

( )

一个使用 Rich 显示训练或评估进度的 `TrainerCallback`。

WinRateCallback

class trl.WinRateCallback

< >

( judge: BasePairwiseJudge trainer: Trainer generation_config: typing.Optional[transformers.generation.configuration_utils.GenerationConfig] = None num_prompts: typing.Optional[int] = None shuffle_order: bool = True use_soft_judge: bool = False )

参数

  • judge (BasePairwiseJudge) — 用于比较生成内容的评判器。
  • trainer (Trainer) — 回调函数将附加到的训练器。训练器的评估数据集必须包含一个名为 `"prompt"` 的列,其中包含用于生成内容的提示。如果 `Trainer` 有一个参考模型(通过 `ref_model` 属性),它将使用此参考模型生成参考内容;否则,它将默认使用初始模型。
  • generation_config (GenerationConfig, 可选) — 用于生成内容的生成配置。
  • num_prompts (intNone, 可选, 默认为 None) — 要为其生成内容的提示数量。如果未提供,则默认为评估数据集中的示例数量。
  • shuffle_order (bool, 可选, 默认为 True) — 是否在评判前打乱生成内容的顺序。
  • use_soft_judge (bool, 可选, 默认为 False) — 是否使用一个软评判器,它为第一个生成内容与第二个生成内容的对比返回一个 0 到 1 之间的获胜概率。

一个 `TrainerCallback`,用于根据参考模型计算模型的胜率。

它使用来自评估数据集的提示生成内容,并将训练好的模型的输出与参考模型进行比较。参考模型要么是模型的初始版本(训练前),要么是训练器中可用的参考模型。在每个评估步骤中,评判器会确定训练好的模型生成的内容相比参考模型获胜的频率。然后,胜率会记录在训练器的日志中,键为 `"eval_win_rate"`。

用法

trainer = DPOTrainer(...)
judge = PairRMJudge()
win_rate_callback = WinRateCallback(judge=judge, trainer=trainer)
trainer.add_callback(win_rate_callback)

LogCompletionsCallback

class trl.LogCompletionsCallback

< >

( trainer: Trainer generation_config: typing.Optional[transformers.generation.configuration_utils.GenerationConfig] = None num_prompts: typing.Optional[int] = None freq: typing.Optional[int] = None )

参数

  • trainer (Trainer) — 回调函数将附加到的训练器。训练器的评估数据集必须包含一个名为 `"prompt"` 的列,其中包含用于生成内容的提示。
  • generation_config (GenerationConfig, 可选) — 用于生成内容的生成配置。
  • num_prompts (intNone, 可选) — 要为其生成内容的提示数量。如果未提供,则默认为评估数据集中的示例数量。
  • freq (intNone, 可选) — 记录生成内容的频率。如果未提供,则默认为训练器的 `eval_steps`。

一个 `TrainerCallback`,用于将生成的内容记录到 Weights & Biases 和/或 Comet 中。

用法

trainer = DPOTrainer(...)
completions_callback = LogCompletionsCallback(trainer=trainer)
trainer.add_callback(completions_callback)

MergeModelCallback

class trl.MergeModelCallback

< >

( merge_config: typing.Optional[ForwardRef('MergeConfig')] = None merge_at_every_checkpoint: bool = False push_to_hub: bool = False )

参数

  • merge_config (MergeConfig, 可选, 默认为 None) — 用于合并过程的配置。如果未提供,则使用默认的 `MergeConfig`。
  • merge_at_every_checkpoint (bool, 可选, 默认为 False) — 是否在每个检查点合并模型。
  • push_to_hub (bool, 可选, 默认为 False) — 合并后是否将合并后的模型推送到 Hub。

一个 `TrainerCallback`,用于根据合并配置将策略模型(正在训练的模型)与另一个模型进行合并。

示例

# pip install mergekit

from trl.mergekit_utils import MergeConfig
from trl import MergeModelCallback

config = MergeConfig()
merge_callback = MergeModelCallback(config)
trainer = DPOTrainer(..., callbacks=[merge_callback])
< > 在 GitHub 上更新