Transformers 文档

回调

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

回调函数

回调函数是能够自定义 PyTorch Trainer 中训练循环行为的对象(此功能尚未在 TensorFlow 中实现),它可以检查训练循环状态(用于进度报告、在 TensorBoard 或其他 ML 平台上记录日志……)并做出决策(如提前停止)。

回调函数是“只读”的代码片段,除了它们返回的 TrainerControl 对象外,它们不能改变训练循环中的任何内容。对于需要改变训练循环的自定义,您应该子类化 Trainer 并覆盖您需要的方法(参阅 训练器 以获取示例)。

默认情况下,`TrainingArguments.report_to` 设置为 `"all"`,因此 Trainer 将使用以下回调函数。

如果已安装某个包,但您不想使用随附的集成,您可以将 `TrainingArguments.report_to` 更改为仅包含您要使用的集成的列表(例如,`["azure_ml", "wandb"]`)。

实现回调函数的主要类是 TrainerCallback。它获取用于实例化 TrainerTrainingArguments,可以通过 TrainerState 访问该 Trainer 的内部状态,并且可以通过 TrainerControl 对训练循环执行一些操作。

可用回调函数

以下是库中可用的 TrainerCallback 列表

class transformers.integrations.CometCallback

< >

( )

一个将日志发送到 Comet MLTrainerCallback

设置

< >

( args state model )

设置可选的 Comet 集成。

环境

  • COMET_MODE (str, 可选,默认为 get_or_create):控制是创建并记录到新的 Comet 实验还是附加到现有实验。它接受以下值
    • get_or_create:根据是否设置了 `COMET_EXPERIMENT_KEY` 以及该键的实验是否存在自动决定。
    • create:始终创建一个新的 Comet 实验。
    • get:始终尝试附加到现有 Comet 实验。需要设置 `COMET_EXPERIMENT_KEY`。
    • ONLINE:**已弃用**,用于创建在线实验。请改用 `COMET_START_ONLINE=1`。
    • OFFLINE:**已弃用**,用于创建离线实验。请改用 `COMET_START_ONLINE=0`。
    • DISABLED:**已弃用**,用于禁用 Comet 日志记录。请改用 `--report_to` 标志来控制用于记录结果的集成。
  • COMET_PROJECT_NAME (str, 可选):Comet 实验的项目名称。
  • COMET_LOG_ASSETS (str, 可选,默认为 TRUE):是否将训练资产(tf 事件日志、检查点等)记录到 Comet。可以是 TRUEFALSE

有关环境中可配置项的数量,请参阅此处

class transformers.DefaultFlowCallback

< >

( )

一个 TrainerCallback,用于处理日志、评估和检查点的默认训练循环流。

class transformers.PrinterCallback

< >

( )

一个只打印日志的 TrainerCallback

class transformers.ProgressCallback

< >

( max_str_len: int = 100 )

一个 TrainerCallback,显示训练或评估的进度。您可以修改 `max_str_len` 来控制日志记录时字符串截断的长度。

class transformers.EarlyStoppingCallback

< >

( early_stopping_patience: int = 1 early_stopping_threshold: typing.Optional[float] = 0.0 )

参数

  • early_stopping_patience (int) — 与 metric_for_best_model 一起使用,当指定指标在 `early_stopping_patience` 次评估调用后恶化时停止训练。
  • early_stopping_threshold(float, 可选) — 与 TrainingArguments `metric_for_best_model` 和 `early_stopping_patience` 一起使用,表示指定指标必须提高多少才能满足提前停止条件。`

一个处理提前停止的 TrainerCallback

此回调函数依赖于 TrainingArguments 参数 *load_best_model_at_end* 功能,用于设置 TrainerState 中的 best_metric。请注意,如果 TrainingArguments 参数 *save_steps* 与 *eval_steps* 不同,则提前停止不会发生,直到下一个保存步骤。

class transformers.integrations.TensorBoardCallback

< >

( tb_writer = None )

参数

  • tb_writer (SummaryWriter, 可选) — 要使用的写入器。如果未设置,将实例化一个新的。

一个将日志发送到 TensorBoardTrainerCallback

class transformers.integrations.WandbCallback

< >

( )

一个将指标、媒体、模型检查点记录到 Weight and BiasesTrainerCallback

设置

< >

( args state model **kwargs )

设置可选的 Weights & Biases (wandb) 集成。

可以子类化并重写此方法以根据需要自定义设置。欲了解更多信息,请参阅此处。您还可以覆盖以下环境变量

环境

  • WANDB_LOG_MODEL (str, 可选, 默认为 "false"):是否在训练期间记录模型和检查点。可以是 "end", "checkpoint""false"。如果设置为 "end",模型将在训练结束时上传。如果设置为 "checkpoint",检查点将每 args.save_steps 上传一次。如果设置为 "false",模型将不会上传。与 load_best_model_at_end() 一起使用以上传最佳模型。

    5.0 版本中已弃用

    在 🤗 Transformers 的 5.0 版本中,将 WANDB_LOG_MODEL 设置为 bool 将被弃用。

  • WANDB_WATCH (str, 可选,默认为 "false"):可以是 "gradients", "all", "parameters""false"。设置为 "all" 以记录梯度和参数。

  • WANDB_PROJECT (str, 可选,默认为 "huggingface"):将其设置为自定义字符串以将结果存储在不同的项目中。

  • WANDB_DISABLED (bool, 可选,默认为 False):是否完全禁用 wandb。设置为 `WANDB_DISABLED=true` 以禁用。

class transformers.integrations.MLflowCallback

< >

( )

一个将日志发送到 MLflowTrainerCallback。可以通过设置环境变量 `DISABLE_MLFLOW_INTEGRATION = TRUE` 来禁用。

设置

< >

( args state model )

设置可选的 MLflow 集成。

环境

  • HF_MLFLOW_LOG_ARTIFACTS (str, 可选):是否使用 MLflow `log_artifact()` 功能记录工件。这仅在记录到远程服务器(例如 s3 或 GCS)时才有意义。如果设置为 `True` 或 *1*,则在 TrainingArguments 的 `output_dir` 中每次保存时将每个保存的检查点复制到本地或远程工件存储。在没有远程存储的情况下使用它只会将文件复制到您的工件位置。
  • MLFLOW_TRACKING_URI (str, 可选):是否将运行存储在特定路径或远程服务器。默认情况下未设置,这将完全跳过设置跟踪 URI。
  • MLFLOW_EXPERIMENT_NAME (str, 可选, 默认为 None):是否使用 MLflow 实验名称来启动运行。默认为 None,这将指向 MLflow 中的 `Default` 实验。否则,它是要激活的实验的区分大小写名称。如果不存在具有此名称的实验,则会创建一个具有此名称的新实验。
  • MLFLOW_TAGS (str, 可选):键/值对字典的字符串转储,将作为标签添加到 MLflow 运行中。示例:`os.environ['MLFLOW_TAGS']='{"release.candidate": "RC1", "release.version": "2.2.0"}'`。
  • MLFLOW_NESTED_RUN (str, 可选):是否使用 MLflow 嵌套运行。如果设置为 True 或 *1*,将在当前运行中创建一个嵌套运行。
  • MLFLOW_RUN_ID (str, 可选):允许重新连接到现有运行,这在从检查点恢复训练时很有用。当设置了 `MLFLOW_RUN_ID` 环境变量时,`start_run` 尝试恢复具有指定运行 ID 的运行,其他参数将被忽略。
  • MLFLOW_FLATTEN_PARAMS (str, 可选,默认为 False):是否在记录之前展平参数字典。
  • MLFLOW_MAX_LOG_PARAMS (int, 可选):设置在运行中记录的最大参数数量。

class transformers.integrations.AzureMLCallback

< >

( azureml_run = None )

一个将日志发送到 AzureMLTrainerCallback

class transformers.integrations.CodeCarbonCallback

< >

( )

一个追踪训练过程中二氧化碳排放量的 TrainerCallback

class transformers.integrations.NeptuneCallback

< >

( api_token: typing.Optional[str] = None project: typing.Optional[str] = None name: typing.Optional[str] = None base_namespace: str = 'finetuning' run = None log_parameters: bool = True log_checkpoints: typing.Optional[str] = None **neptune_run_kwargs )

参数

  • api_token (str, 可选) — 注册时获取的 Neptune API 令牌。如果您已将令牌保存到 `NEPTUNE_API_TOKEN` 环境变量中(强烈推荐),则可以省略此参数。请参阅文档中的完整设置说明。
  • project (str, 可选) — 现有 Neptune 项目的名称,格式为“workspace-name/project-name”。您可以在 Neptune 的项目设置 -> 属性中找到并复制该名称。如果为 None(默认),则使用 `NEPTUNE_PROJECT` 环境变量的值。
  • name (str, 可选) — 运行的自定义名称。
  • base_namespace (str, 可选, 默认为“finetuning”) — 在 Neptune 运行中,将包含回调函数记录的所有元数据的根命名空间。
  • log_parameters (bool, 可选, 默认为 True) — 如果为 True,则记录 Trainer 提供的所有 Trainer 参数和模型参数。
  • log_checkpoints (str, 可选) — 如果为“same”,则在 Trainer 保存检查点时上传检查点。如果为“last”,则仅上传最近保存的检查点。如果为“best”,则上传最佳检查点(Trainer 保存的检查点中)。如果为 `None`,则不上传检查点。
  • run (Run, 可选) — 如果您想继续记录到现有运行,请传入一个 Neptune 运行对象。有关恢复运行的更多信息,请参阅文档
  • **neptune_run_kwargs (可选) — 在创建新运行时直接传递给 neptune.init_run() 函数的其他关键字参数。

将日志发送到 Neptune 的 TrainerCallback。

有关说明和示例,请参阅 Neptune 文档中的Transformers 集成指南

class transformers.integrations.ClearMLCallback

< >

( )

一个将日志发送到 ClearMLTrainerCallback

环境

  • CLEARML_PROJECT (str, 可选,默认为 HuggingFace Transformers):ClearML 项目名称。
  • CLEARML_TASK (str, 可选,默认为 Trainer):ClearML 任务名称。
  • CLEARML_LOG_MODEL (bool, 可选,默认为 False):是否在训练期间将模型记录为工件。

class transformers.integrations.DagsHubCallback

< >

( )

一个将日志记录到 DagsHubTrainerCallback。继承自 `MLflowCallback`

设置

< >

( *args **kwargs )

设置 DagsHub 的日志集成。

环境

  • HF_DAGSHUB_LOG_ARTIFACTS (str, 可选):是否保存实验的数据和模型工件。默认为 `False`。

class transformers.integrations.FlyteCallback

< >

( save_log_history: bool = True sync_checkpoints: bool = True )

参数

  • save_log_history (bool, 可选, 默认为 True) — 如果设置为 True,训练日志将作为 Flyte Deck 保存。
  • sync_checkpoints (bool, 可选, 默认为 True) — 如果设置为 True,检查点将与 Flyte 同步,并且可以在中断时用于恢复训练。

一个将日志发送到 FlyteTrainerCallback。注意:此回调函数仅在 Flyte 任务中有效。

示例

# Note: This example skips over some setup steps for brevity.
from flytekit import current_context, task


@task
def train_hf_transformer():
    cp = current_context().checkpoint
    trainer = Trainer(..., callbacks=[FlyteCallback()])
    output = trainer.train(resume_from_checkpoint=cp.restore())

class transformers.integrations.DVCLiveCallback

< >

( live: typing.Optional[typing.Any] = None log_model: typing.Union[typing.Literal['all'], bool, NoneType] = None **kwargs )

参数

  • live (dvclive.Live, 可选, 默认为 None) — 可选的 Live 实例。如果为 None,将使用 **kwargs 创建一个新实例。
  • log_model (Union[Literal[“all”], bool], 可选, 默认为 None) — 是否使用 dvclive.Live.log_artifact() 记录由 Trainer 创建的检查点。如果设置为 True,最终检查点将在训练结束时记录。如果设置为 "all",整个 TrainingArguments 的 `output_dir` 将在每个检查点记录。

一个将日志发送到 DVCLiveTrainerCallback

在 `setup` 中使用以下环境变量配置集成。要在此环境变量之外自定义此回调函数,请参阅此处

设置

< >

( args state model )

设置可选的 DVCLive 集成。要在此环境变量之外自定义此回调函数,请参阅此处

环境

  • HF_DVCLIVE_LOG_MODEL (str, 可选):是否使用 `dvclive.Live.log_artifact()` 记录由 Trainer 创建的检查点。如果设置为 `True` 或 *1*,最终检查点将在训练结束时记录。如果设置为 `all`,整个 TrainingArguments 的 `output_dir` 将在每个检查点记录。

class transformers.integrations.SwanLabCallback

< >

( )

一个将指标、媒体、模型检查点记录到 SwanLabTrainerCallback

设置

< >

( args state model **kwargs )

设置可选的 SwanLab (swanlab) 集成。

如有需要,可以子类化并覆盖此方法以自定义设置。更多信息请参阅此处

您还可以覆盖以下环境变量。更多关于环境变量的信息请参阅此处

环境

  • SWANLAB_API_KEY (str, 可选, 默认为 None): 云API密钥。登录时,首先检查此环境变量。如果不存在,系统会检查用户是否已登录。如果未登录,则启动登录过程。

    • 如果将字符串传递给登录接口,则忽略此环境变量。
    • 如果用户已登录,此环境变量优先于本地存储的登录信息。
  • SWANLAB_PROJECT (str, 可选, 默认为 None): 将此设置为自定义字符串,以便将结果存储在不同的项目中。如果未指定,则使用当前运行目录的名称。

  • SWANLAB_LOG_DIR (str, 可选, 默认为 swanlog): 此环境变量指定在本地模式下运行时的日志文件存储路径。默认情况下,日志保存在工作目录下名为 swanlog 的文件夹中。

  • SWANLAB_MODE (Literal["local", "cloud", "disabled"], 可选, 默认为 cloud): SwanLab 的解析模式,涉及操作员注册的回调。目前有三种模式:本地、云和禁用。注意:区分大小写。更多信息请参阅此处

  • SWANLAB_LOG_MODEL (str, 可选, 默认为 None): SwanLab 目前不支持保存模式功能。此功能将在未来版本中提供。

  • SWANLAB_WEB_HOST (str, 可选, 默认为 None): 私有版本 SwanLab 云环境的 Web 地址(免费)

  • SWANLAB_API_HOST (str, 可选, 默认为 None): 私有版本 SwanLab 云环境的 API 地址(免费)

TrainerCallback

class transformers.TrainerCallback

< >

( )

参数

  • args (TrainingArguments) — 用于实例化 Trainer 的训练参数。
  • state (TrainerState) — Trainer 的当前状态。
  • control (TrainerControl) — 返回给 Trainer 的对象,可用于做出某些决策。
  • model (PreTrainedModeltorch.nn.Module) — 正在训练的模型。
  • tokenizer (PreTrainedTokenizer) — 用于数据编码的分词器。此参数已被弃用,建议使用 processing_class
  • processing_class ([PreTrainedTokenizerBaseImageProcessorProcessorMixinFeatureExtractionMixin]) — 用于数据编码的处理类。可以是分词器、处理器、图像处理器或特征提取器。
  • optimizer (torch.optim.Optimizer) — 用于训练步骤的优化器。
  • lr_scheduler (torch.optim.lr_scheduler.LambdaLR) — 用于设置学习率的调度器。
  • train_dataloader (torch.utils.data.DataLoader, 可选) — 当前用于训练的数据加载器。
  • eval_dataloader (torch.utils.data.DataLoader, 可选) — 当前用于评估的数据加载器。
  • metrics (dict[str, float]) — 上一评估阶段计算的指标。

    这些仅在 on_evaluate 事件中可访问。

  • logs (dict[str, float]) — 要记录的值。

    这些仅在 on_log 事件中可访问。

一个类,用于在某些事件中检查训练循环状态并做出决策的对象。在每个这些事件中,以下参数可用:

control 对象是唯一可以被回调更改的对象,在这种情况下,更改它的事件应返回修改后的版本。

参数 argsstatecontrol 是所有事件的位置参数,所有其他参数都分组在 kwargs 中。您可以使用它们的签名解包您需要的参数。例如,请参见简单 PrinterCallback 的代码。

示例

class PrinterCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        _ = logs.pop("total_flos", None)
        if state.is_local_process_zero:
            print(logs)

on_epoch_begin

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

一个 epoch 开始时调用的事件。

on_epoch_end

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

一个 epoch 结束时调用的事件。

on_evaluate

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

评估阶段结束后调用的事件。

on_init_end

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

Trainer 初始化结束时调用的事件。

on_log

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

记录最新日志后调用的事件。

on_optimizer_step

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

优化器步骤之后但在梯度归零之前调用的事件。用于监控梯度很有用。

on_pre_optimizer_step

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

优化器步骤之前但在梯度裁剪之后调用的事件。用于监控梯度很有用。

on_predict

< >

( args: TrainingArguments state: TrainerState control: TrainerControl metrics **kwargs )

成功预测后调用的事件。

on_prediction_step

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

预测步骤后调用的事件。

on_save

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

检查点保存后调用的事件。

on_step_begin

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

训练步骤开始时调用的事件。如果使用梯度累积,一个训练步骤可能需要多个输入。

on_step_end

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

训练步骤结束时调用的事件。如果使用梯度累积,一个训练步骤可能需要多个输入。

on_substep_end

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

梯度累积期间一个子步骤结束时调用的事件。

on_train_begin

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

训练开始时调用的事件。

on_train_end

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

训练结束时调用的事件。

以下是如何向 PyTorch Trainer 注册自定义回调的示例

class MyCallback(TrainerCallback):
    "A callback that prints a message at the beginning of training"

    def on_train_begin(self, args, state, control, **kwargs):
        print("Starting training")


trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[MyCallback],  # We can either pass the callback class this way or an instance of it (MyCallback())
)

另一种注册回调的方法是调用 trainer.add_callback(),如下所示

trainer = Trainer(...)
trainer.add_callback(MyCallback)
# Alternatively, we can pass an instance of the callback class
trainer.add_callback(MyCallback())

TrainerState

class transformers.TrainerState

< >

( epoch: typing.Optional[float] = None global_step: int = 0 max_steps: int = 0 logging_steps: int = 500 eval_steps: int = 500 save_steps: int = 500 train_batch_size: typing.Optional[int] = None num_train_epochs: int = 0 num_input_tokens_seen: int = 0 total_flos: float = 0 log_history: list = None best_metric: typing.Optional[float] = None best_global_step: typing.Optional[int] = None best_model_checkpoint: typing.Optional[str] = None is_local_process_zero: bool = True is_world_process_zero: bool = True is_hyper_param_search: bool = False trial_name: typing.Optional[str] = None trial_params: dict = None stateful_callbacks: list = None )

参数

  • epoch (float, 可选) — 仅在训练期间设置,表示训练所在的 epoch(小数部分表示当前 epoch 完成的百分比)。
  • global_step (int, 可选, 默认为 0) — 训练期间,表示已完成的更新步骤数。
  • max_steps (int, 可选, 默认为 0) — 当前训练期间要执行的更新步骤数。
  • logging_steps (int, 可选, 默认为 500) — 每 X 更新步骤记录一次日志。
  • eval_steps (int, 可选) — 每 X 步骤运行一次评估。
  • save_steps (int, 可选, 默认为 500) — 每 X 更新步骤保存一次检查点。
  • train_batch_size (int, 可选) — 训练数据加载器的批大小。仅在使用了 auto_find_batch_size 时需要。
  • num_input_tokens_seen (int, 可选, 默认为 0) — 当跟踪输入 token 时,训练期间已看到的 token 数量(输入 token 的数量,而非预测 token 的数量)。
  • total_flos (float, 可选, 默认为 0) — 模型自训练开始以来完成的浮点运算总数(存储为浮点数以避免溢出)。
  • log_history (list[dict[str, float]], 可选) — 自训练开始以来的日志列表。
  • best_metric (float, 可选) — 当跟踪最佳模型时,迄今为止遇到的最佳指标值。
  • best_global_step (int, 可选) — 当跟踪最佳模型时,遇到最佳指标时的步骤。用于设置 best_model_checkpoint
  • best_model_checkpoint (str, 可选) — 当跟踪最佳模型时,迄今为止遇到的最佳模型的检查点名称。
  • is_local_process_zero (bool, 可选, 默认为 True) — 此进程是否为本地(例如,如果以分布式方式在多台机器上训练)主进程。
  • is_world_process_zero (bool, 可选, 默认为 True) — 此进程是否为全局主进程(当以分布式方式在多台机器上训练时,此参数仅对一个进程为 True)。
  • is_hyper_param_search (bool, 可选, 默认为 False) — 是否正在使用 Trainer.hyperparameter_search 进行超参数搜索。这将影响数据在 TensorBoard 中的记录方式。
  • stateful_callbacks (list[StatefulTrainerCallback], 可选) — 附加到 Trainer 的回调,其状态应被保存或恢复。相关回调应实现 statefrom_state 函数。

一个包含 Trainer 内部状态的类,该状态将在检查点时与模型和优化器一起保存,并传递给 TrainerCallback

在本类中,一步被理解为一次更新步。当使用梯度累积时,一次更新步可能需要多次正向和反向传播:如果您使用 gradient_accumulation_steps=n,那么一次更新步需要经过 *n* 个批次。

compute_steps

< >

( args max_steps )

根据是否为比例来计算并存储用于日志记录、评估和保存步骤的绝对值。

init_training_references

< >

( trainer max_steps num_train_epochs trial )

存储 self 中所需的初始训练引用

load_from_json

< >

( json_path: str )

json_path 的内容创建实例。

save_to_json

< >

( json_path: str )

以 JSON 格式将此实例的内容保存在 json_path 中。

TrainerControl

transformers.TrainerControl

< >

( should_training_stop: bool = False should_epoch_stop: bool = False should_save: bool = False should_evaluate: bool = False should_log: bool = False )

参数

  • should_training_stop (bool, 可选, 默认为 False) — 训练是否应该中断。

    如果为 True,此变量将不会被重置为 False。训练将直接停止。

  • should_epoch_stop (bool, 可选, 默认为 False) — 当前 epoch 是否应该中断。

    如果为 True,此变量将在下一个 epoch 开始时被重置为 False

  • should_save (bool, 可选, 默认为 False) — 模型是否应该在此步保存。

    如果为 True,此变量将在下一步开始时被重置为 False

  • should_evaluate (bool, 可选, 默认为 False) — 模型是否应该在此步评估。

    如果为 True,此变量将在下一步开始时被重置为 False

  • should_log (bool, 可选, 默认为 False) — 日志是否应该在此步报告。

    如果为 True,此变量将在下一步开始时被重置为 False

一个处理 Trainer 控制流的类。此类由 TrainerCallback 使用,以激活训练循环中的某些开关。

< > 在 GitHub 上更新