TRL 文档
回调
并获取增强的文档体验
开始使用
回调
SyncRefModelCallback
class trl.SyncRefModelCallback
< source >( ref_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] accelerator: typing.Optional[accelerate.accelerator.Accelerator] )
用于同步模型和参考模型的回调。
RichProgressCallback
一个 TrainerCallback
,它使用 Rich 显示训练或评估的进度。
WinRateCallback
class trl.WinRateCallback
< source >( judge: BasePairwiseJudge trainer: Trainer generation_config: typing.Optional[transformers.generation.configuration_utils.GenerationConfig] = None num_prompts: typing.Optional[int] = None shuffle_order: bool = True use_soft_judge: bool = False )
参数
- judge (
BasePairwiseJudge
) — 用于比较补全结果的评判器。 - trainer (
Trainer
) — 回调将附加到的训练器。训练器的评估数据集必须包含一个"prompt"
列,其中包含用于生成补全的提示。如果Trainer
有一个参考模型(通过ref_model
属性),它将使用此参考模型来生成参考补全;否则,它默认使用初始模型。 - generation_config (
GenerationConfig
, 可选) — 用于生成补全的生成配置。 - num_prompts (
int
或None
, 可选, 默认为None
) — 用于生成补全的提示数量。如果未提供,则默认为评估数据集中的示例数。 - shuffle_order (
bool
, 可选, 默认为True
) — 是否在评判前打乱补全顺序。 - use_soft_judge (
bool
, 可选, 默认为False
) — 是否使用软评判器,该评判器返回第一个补全相对于第二个补全的胜率概率,介于 0 和 1 之间。
一个 TrainerCallback
,用于计算模型基于参考的胜率。
它使用评估数据集中的提示生成补全,并将训练模型的输出与参考进行比较。参考可以是模型的初始版本(训练前)或参考模型(如果在训练器中可用)。在每个评估步骤中,评判器确定使用评判器时,训练模型的补全胜过参考的频率。然后,胜率将记录在训练器的日志中,键为 "eval_win_rate"
。
LogCompletionsCallback
class trl.LogCompletionsCallback
< source >( trainer: Trainer generation_config: typing.Optional[transformers.generation.configuration_utils.GenerationConfig] = None num_prompts: typing.Optional[int] = None freq: typing.Optional[int] = None )
一个 TrainerCallback
,它将补全记录到 Weights & Biases 和/或 Comet。
MergeModelCallback
class trl.MergeModelCallback
< source >( merge_config: typing.Optional[ForwardRef('MergeConfig')] = None merge_at_every_checkpoint: bool = False push_to_hub: bool = False )
一个 TrainerCallback
,它基于合并配置将策略模型(正在训练的模型)与另一个模型合并。