Transformers 文档

回调

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

回调

回调是可以在 PyTorch Trainer 中自定义训练循环行为的对象(此功能尚未在 TensorFlow 中实现),它可以检查训练循环状态(用于进度报告、在 TensorBoard 或其他 ML 平台上记录日志……)并做出决策(如提前停止)。

除了它们返回的 TrainerControl 对象之外,回调是“只读”代码片段,它们无法更改训练循环中的任何内容。对于需要在训练循环中进行更改的自定义,您应该继承 Trainer 并覆盖您需要的方法(请参阅 trainer 获取示例)。

默认情况下,TrainingArguments.report_to 设置为 "all",因此 Trainer 将使用以下回调。

如果安装了某个包,但您不希望使用随附的集成,则可以将 TrainingArguments.report_to 更改为仅包含您想要使用的集成的列表(例如,["azure_ml", "wandb"])。

实现回调的主类是 TrainerCallback。它获取用于实例化 TrainerTrainingArguments,可以通过 TrainerState 访问该 Trainer 的内部状态,并且可以通过 TrainerControl 对训练循环执行某些操作。

可用的回调

以下是库中可用的 TrainerCallback 列表

class transformers.integrations.CometCallback

< >

( )

一个 TrainerCallback,用于将日志发送到 Comet ML

设置

< >

( args state model )

设置可选的 Comet 集成。

环境

  • COMET_MODE (str, 可选, 默认为 get_or_create):控制是创建新的 Comet 实验并记录日志,还是附加到现有实验。它接受以下值
    • get_or_create:根据是否设置了 COMET_EXPERIMENT_KEY 以及是否已存在具有该键的实验来自动决定。
    • create:始终创建新的 Comet 实验。
    • get:始终尝试附加到现有 Comet 实验。需要设置 COMET_EXPERIMENT_KEY
    • ONLINE已弃用,用于创建在线实验。请改用 COMET_START_ONLINE=1
    • OFFLINE已弃用,用于创建离线实验。请改用 COMET_START_ONLINE=0
    • DISABLED已弃用,用于禁用 Comet 日志记录。请改用 --report_to 标志来控制用于记录结果的集成。
  • COMET_PROJECT_NAME (str, 可选):Comet 项目实验名称。
  • COMET_LOG_ASSETS (str, 可选, 默认为 TRUE):是否将训练资产(tf 事件日志、检查点等)记录到 Comet。可以是 TRUEFALSE

有关环境中许多可配置项的信息,请参阅此处

class transformers.DefaultFlowCallback

< >

( )

一个 TrainerCallback,用于处理训练循环中日志、评估和检查点的默认流程。

class transformers.PrinterCallback

< >

( )

一个简单的 TrainerCallback,仅打印日志。

class transformers.ProgressCallback

< >

( max_str_len: int = 100 )

一个 TrainerCallback,用于显示训练或评估的进度。您可以修改 max_str_len 以控制记录日志时字符串被截断的长度。

class transformers.EarlyStoppingCallback

< >

( early_stopping_patience: int = 1 early_stopping_threshold: typing.Optional[float] = 0.0 )

参数

  • early_stopping_patience (int) — 与 metric_for_best_model 一起使用,以便在指定的指标在 early_stopping_patience 次评估调用中变差时停止训练。
  • early_stopping_threshold(float, 可选) — 与 TrainingArguments metric_for_best_modelearly_stopping_patience 一起使用,以表示指定的指标必须改进多少才能满足提前停止条件。 `

一个 TrainerCallback,用于处理提前停止。

此回调依赖于 TrainingArguments 参数 load_best_model_at_end 功能,以在 TrainerState 中设置 best_metric。请注意,如果 TrainingArguments 参数 save_stepseval_steps 不同,则提前停止将不会发生,直到下一个保存步骤。

class transformers.integrations.TensorBoardCallback

< >

( tb_writer = None )

参数

  • tb_writer (SummaryWriter, 可选) — 要使用的 writer。如果未设置,将实例化一个。

一个 TrainerCallback,用于将日志发送到 TensorBoard

class transformers.integrations.WandbCallback

< >

( )

一个 TrainerCallback,用于将指标、媒体、模型检查点记录到 Weight and Biases

设置

< >

( args state model **kwargs )

设置可选的 Weights & Biases (wandb) 集成。

可以子类化并重写此方法以自定义设置(如果需要)。 更多信息请见 此处。 您还可以覆盖以下环境变量

环境

  • WANDB_LOG_MODEL (str, 可选, 默认为 "false"): 是否在训练期间记录模型和检查点。 可以是 "end", "checkpoint""false"。 如果设置为 "end",模型将在训练结束时上传。 如果设置为 "checkpoint",检查点将在每次 args.save_steps 时上传。 如果设置为 "false",模型将不会上传。 与 load_best_model_at_end() 一起使用以上传最佳模型。

    在 5.0 版本中已弃用

    在 🤗 Transformers 的版本 5 中,将 WANDB_LOG_MODEL 设置为 bool 类型将被弃用。

  • WANDB_WATCH (str, 可选 默认为 "false"): 可以是 "gradients", "all", "parameters", 或 "false"。 设置为 "all" 以记录梯度和参数。

  • WANDB_PROJECT (str, 可选, 默认为 "huggingface"): 将其设置为自定义字符串以将结果存储在不同的项目中。

  • WANDB_DISABLED (bool, 可选, 默认为 False): 是否完全禁用 wandb。 设置 WANDB_DISABLED=true 以禁用。

class transformers.integrations.MLflowCallback

< >

( )

一个 TrainerCallback,用于将日志发送到 MLflow。 可以通过设置环境变量 DISABLE_MLFLOW_INTEGRATION = TRUE 来禁用。

设置

< >

( args state model )

设置可选的 MLflow 集成。

环境

  • HF_MLFLOW_LOG_ARTIFACTS (str, 可选): 是否使用 MLflow 的 .log_artifact() 功能来记录工件。 这仅在记录到远程服务器(例如 s3 或 GCS)时才有意义。 如果设置为 True1,则会在每次保存在 TrainingArgumentsoutput_dir 中的检查点时,将每个保存的检查点复制到本地或远程工件存储。 在没有远程存储的情况下使用它只会将文件复制到您的工件位置。
  • MLFLOW_TRACKING_URI (str, 可选): 是否将运行存储在特定路径或远程服务器上。 默认未设置,这将完全跳过设置跟踪 URI。
  • MLFLOW_EXPERIMENT_NAME (str, 可选, 默认为 None): 是否使用 MLflow experiment_name,在该 experiment_name 下启动运行。 默认为 None,这将指向 MLflow 中的 Default 实验。 否则,它是要激活的实验的区分大小写的名称。 如果具有此名称的实验不存在,则会创建一个具有此名称的新实验。
  • MLFLOW_TAGS (str, 可选): 要添加到 MLflow 运行中作为标签的键/值对字典的字符串转储。 示例: os.environ['MLFLOW_TAGS']='{"release.candidate": "RC1", "release.version": "2.2.0"}'
  • MLFLOW_NESTED_RUN (str, 可选): 是否使用 MLflow 嵌套运行。 如果设置为 True1,将在当前运行内部创建一个嵌套运行。
  • MLFLOW_RUN_ID (str, 可选): 允许重新连接到现有运行,这在从检查点恢复训练时非常有用。 当设置 MLFLOW_RUN_ID 环境变量时,start_run 会尝试恢复具有指定运行 ID 的运行,并且其他参数将被忽略。
  • MLFLOW_FLATTEN_PARAMS (str, 可选, 默认为 False): 是否在记录之前展平参数字典。
  • MLFLOW_MAX_LOG_PARAMS (int, 可选): 设置运行中要记录的最大参数数量。

class transformers.integrations.AzureMLCallback

< >

( azureml_run = None )

一个 TrainerCallback,用于将日志发送到 AzureML

class transformers.integrations.CodeCarbonCallback

< >

( )

一个 TrainerCallback,用于跟踪训练的二氧化碳排放量。

class transformers.integrations.NeptuneCallback

< >

( api_token: typing.Optional[str] = None project: typing.Optional[str] = None name: typing.Optional[str] = None base_namespace: str = 'finetuning' run = None log_parameters: bool = True log_checkpoints: typing.Optional[str] = None **neptune_run_kwargs )

参数

  • api_token (str, 可选) — 注册后获得的 Neptune API 令牌。 如果您已将令牌保存到 NEPTUNE_API_TOKEN 环境变量中(强烈推荐),则可以省略此参数。 有关完整的设置说明,请参阅 文档
  • project (str, 可选) — 现有 Neptune 项目的名称,格式为 “workspace-name/project-name”。 您可以在 Neptune 中从项目设置 -> 属性中查找并复制名称。 如果为 None(默认值),则使用 NEPTUNE_PROJECT 环境变量的值。
  • name (str, 可选) — 运行的自定义名称。
  • base_namespace (str, 可选, 默认为 “finetuning”) — 在 Neptune 运行中,将包含回调记录的所有元数据的根命名空间。
  • log_parameters (bool, 可选, 默认为 True) — 如果为 True,则记录 Trainer 提供的所有 Trainer 参数和模型参数。
  • log_checkpoints (str, 可选) — 如果为 “same”,则在 Trainer 保存检查点时上传。 如果为 “last”,则仅上传最近保存的检查点。 如果为 “best”,则上传最佳检查点(在 Trainer 保存的检查点中)。 如果为 None,则不上传检查点。
  • run (Run, 可选) — 如果您想继续记录到现有运行,请传递 Neptune 运行对象。 阅读 文档 中关于恢复运行的更多信息。
  • **neptune_run_kwargs (可选) — 要直接传递给 neptune.init_run() 函数的其他关键字参数,当创建新运行时。

TrainerCallback,用于将日志发送到 Neptune

有关说明和示例,请参阅 Neptune 文档中的 Transformers 集成指南

class transformers.integrations.ClearMLCallback

< >

( )

一个 TrainerCallback,用于将日志发送到 ClearML

环境

  • CLEARML_PROJECT (str, 可选, 默认为 HuggingFace Transformers): ClearML 项目名称。
  • CLEARML_TASK (str, 可选, 默认为 Trainer): ClearML 任务名称。
  • CLEARML_LOG_MODEL (bool, 可选, 默认为 False): 是否在训练期间将模型记录为工件。

class transformers.integrations.DagsHubCallback

< >

( )

一个 TrainerCallback,用于记录到 DagsHub。 扩展了 MLflowCallback

设置

< >

( *args **kwargs )

设置 DagsHub 的日志记录集成。

环境

  • HF_DAGSHUB_LOG_ARTIFACTS (str, 可选): 是否保存实验的数据和模型工件。 默认为 False

class transformers.integrations.FlyteCallback

< >

( save_log_history: bool = True sync_checkpoints: bool = True )

参数

  • sync_checkpoints (bool, 可选, 默认为 True) — 当设置为 True 时,检查点会与 Flyte 同步,并且可以在中断的情况下用于恢复训练。

一个 TrainerCallback,用于将日志发送到 Flyte。 注意:此回调仅在 Flyte 任务中有效。

示例

# Note: This example skips over some setup steps for brevity.
from flytekit import current_context, task


@task
def train_hf_transformer():
    cp = current_context().checkpoint
    trainer = Trainer(..., callbacks=[FlyteCallback()])
    output = trainer.train(resume_from_checkpoint=cp.restore())

class transformers.integrations.DVCLiveCallback

< >

( live: typing.Optional[typing.Any] = None log_model: typing.Union[typing.Literal['all'], bool, NoneType] = None **kwargs )

参数

  • live (dvclive.Live, 可选, 默认为 None) — 可选的 Live 实例。如果为 None,将使用 **kwargs 创建新实例。
  • log_model (Union[Literal[“all”], bool], 可选, 默认为 None) — 是否使用 dvclive.Live.log_artifact() 来记录由 Trainer 创建的检查点。如果设置为 True,则在训练结束时记录最终检查点。如果设置为 "all",则在每个检查点记录整个 TrainingArgumentsoutput_dir

一个 TrainerCallback,用于将日志发送到 DVCLive

使用以下环境变量在 setup 中配置集成。 要自定义超出这些环境变量的回调,请参阅 此处

设置

< >

( args state model )

设置可选的 DVCLive 集成。 要自定义超出以下环境变量的回调,请参阅 此处

环境

  • HF_DVCLIVE_LOG_MODEL (str, 可选): 是否使用 dvclive.Live.log_artifact() 来记录由 Trainer 创建的检查点。如果设置为 True1,则在训练结束时记录最终检查点。如果设置为 all,则在每个检查点记录整个 TrainingArgumentsoutput_dir

class transformers.integrations.SwanLabCallback

< >

( )

一个 TrainerCallback,用于将指标、媒体、模型检查点记录到 SwanLab

设置

< >

( args state model **kwargs )

设置可选的 SwanLab (swanlab) 集成。

可以子类化并覆盖此方法以在需要时自定义设置。 更多信息请见 此处

您还可以覆盖以下环境变量。 有关环境变量的更多信息,请参阅 此处

环境

  • SWANLAB_API_KEY (str, 可选, 默认为 None): 云 API 密钥。 登录期间,首先检查此环境变量。 如果它不存在,系统会检查用户是否已登录。 如果没有,则启动登录过程。

    • 如果将字符串传递到登录界面,则此环境变量将被忽略。
    • 如果用户已登录,则此环境变量优先于本地存储的登录信息。
  • SWANLAB_PROJECT (str, 可选, 默认为 None): 将此设置为自定义字符串以在不同的项目中存储结果。 如果未指定,则使用当前运行目录的名称。

  • SWANLAB_LOG_DIR (str, 可选, 默认为 swanlog): 此环境变量指定在本地模式下运行时日志文件的存储路径。 默认情况下,日志保存在工作目录下的名为 swanlog 的文件夹中。

  • SWANLAB_MODE (Literal["local", "cloud", "disabled"], 可选, 默认为 cloud): SwanLab 的解析模式,涉及操作员注册的回调。 目前,有三种模式:local、cloud 和 disabled。 注意:区分大小写。 更多信息请见 此处

  • SWANLAB_LOG_MODEL (str, 可选, 默认为 None): SwanLab 目前不支持保存模式功能。此功能将在未来的版本中提供

  • SWANLAB_WEB_HOST (str, 可选, 默认为 None): SwanLab 云环境的 Web 地址,适用于私有版本(免费)

  • SWANLAB_API_HOST (str, 可选, 默认为 None): SwanLab 云环境的 API 地址,适用于私有版本(免费)

TrainerCallback

class transformers.TrainerCallback

< >

( )

参数

  • args (TrainingArguments) — 用于实例化 Trainer 的训练参数。
  • state (TrainerState) — Trainer 的当前状态。
  • control (TrainerControl) — 返回给 Trainer 并且可以用于做出某些决定的对象。
  • model (PreTrainedModeltorch.nn.Module) — 正在训练的模型。
  • tokenizer (PreTrainedTokenizer) — 用于编码数据的 tokenizer。 此项已弃用,推荐使用 processing_class
  • processing_class ([PreTrainedTokenizerBaseImageProcessorProcessorMixinFeatureExtractionMixin]) — 用于编码数据的 processing class。 可以是 tokenizer、processor、image processor 或 feature extractor。
  • optimizer (torch.optim.Optimizer) — 用于训练步骤的 optimizer。
  • lr_scheduler (torch.optim.lr_scheduler.LambdaLR) — 用于设置学习率的 scheduler。
  • train_dataloader (torch.utils.data.DataLoader, 可选) — 当前用于训练的 dataloader。
  • eval_dataloader (torch.utils.data.DataLoader, 可选) — 当前用于评估的 dataloader。
  • metrics (Dict[str, float]) — 上次评估阶段计算的指标。

    这些指标仅在 on_evaluate 事件中可用。

  • logs (Dict[str, float]) — 要记录的值。

    这些值仅在 on_log 事件中可用。

一个类,用于表示将在某些事件检查训练循环的状态并做出一些决定的对象。 在每个事件中,以下参数可用

control 对象是唯一可以被回调更改的对象,在这种情况下,更改它的事件应返回修改后的版本。

参数 argsstatecontrol 是所有事件的位置参数,所有其他参数都分组在 kwargs 中。 您可以在使用它们的事件签名中解包您需要的参数。 例如,请参阅简单的 PrinterCallback 的代码。

示例

class PrinterCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        _ = logs.pop("total_flos", None)
        if state.is_local_process_zero:
            print(logs)

on_epoch_begin

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

在 epoch 开始时调用的事件。

on_epoch_end

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

在 epoch 结束时调用的事件。

on_evaluate

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

在评估阶段之后调用的事件。

on_init_end

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

Trainer 初始化结束时调用的事件。

on_log

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

在记录最后日志后调用的事件。

on_optimizer_step

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

在优化器步骤之后但在梯度归零之前调用的事件。 用于监控梯度。

on_pre_optimizer_step

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

在优化器步骤之前但在梯度裁剪之后调用的事件。 用于监控梯度。

on_predict

< >

( args: TrainingArguments state: TrainerState control: TrainerControl metrics **kwargs )

在成功预测后调用的事件。

on_prediction_step

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

在预测步骤后调用的事件。

on_save

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

在检查点保存后调用的事件。

on_step_begin

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

在训练步骤开始时调用的事件。 如果使用梯度累积,则一个训练步骤可能需要多个输入。

on_step_end

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

在训练步骤结束时调用的事件。 如果使用梯度累积,则一个训练步骤可能需要多个输入。

on_substep_end

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

在梯度累积期间的子步骤结束时调用的事件。

on_train_begin

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

在训练开始时调用的事件。

on_train_end

< >

( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )

在训练结束时调用的事件。

以下是如何使用 PyTorch Trainer 注册自定义回调的示例

class MyCallback(TrainerCallback):
    "A callback that prints a message at the beginning of training"

    def on_train_begin(self, args, state, control, **kwargs):
        print("Starting training")


trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[MyCallback],  # We can either pass the callback class this way or an instance of it (MyCallback())
)

注册回调的另一种方法是调用 trainer.add_callback(),如下所示

trainer = Trainer(...)
trainer.add_callback(MyCallback)
# Alternatively, we can pass an instance of the callback class
trainer.add_callback(MyCallback())

TrainerState

class transformers.TrainerState

< >

( epoch: typing.Optional[float] = None global_step: int = 0 max_steps: int = 0 logging_steps: int = 500 eval_steps: int = 500 save_steps: int = 500 train_batch_size: int = None num_train_epochs: int = 0 num_input_tokens_seen: int = 0 total_flos: float = 0 log_history: list = None best_metric: typing.Optional[float] = None best_global_step: typing.Optional[int] = None best_model_checkpoint: typing.Optional[str] = None is_local_process_zero: bool = True is_world_process_zero: bool = True is_hyper_param_search: bool = False trial_name: str = None trial_params: dict = None stateful_callbacks: list = None )

参数

  • epoch (float, 可选) — 仅在训练期间设置,将表示训练所处的 epoch(小数部分是当前 epoch 完成的百分比)。
  • global_step (int, 可选, 默认为 0) — 在训练期间,表示已完成的更新步骤数。
  • max_steps (int, 可选, 默认为 0) — 当前训练期间要执行的更新步骤数。
  • logging_steps (int, 可选, 默认为 500) — 每 X 个更新步骤记录日志
  • eval_steps (int, 可选) — 每 X 步运行一次评估。
  • save_steps (int, 可选, 默认为 500) — 每 X 个更新步骤保存检查点。
  • train_batch_size (int, 可选) — 训练数据加载器的批次大小。 仅在使用了 auto_find_batch_size 时才需要。
  • num_input_tokens_seen (int, 可选, 默认为 0) — 当跟踪输入 tokens 时,训练期间看到的 tokens 数量(输入 tokens 的数量,而不是预测 tokens 的数量)。
  • total_flos (float, 可选, 默认为 0) — 模型自训练开始以来完成的浮点运算总数(存储为浮点数以避免溢出)。
  • log_history (List[Dict[str, float]], 可选) — 自训练开始以来完成的日志列表。
  • best_metric (float, 可选) — 当跟踪最佳模型时,到目前为止遇到的最佳指标值。
  • best_global_step (int, 可选) — 当跟踪最佳模型时,遇到最佳指标的步数。用于设置 best_model_checkpoint
  • best_model_checkpoint (str, 可选) — 当跟踪最佳模型时,到目前为止遇到的最佳模型的检查点名称的值。
  • is_local_process_zero (bool, 可选, 默认为 True) — 此进程是否为本地主进程(例如,如果多台机器以分布式方式进行训练,则在一台机器上)。
  • is_world_process_zero (bool, 可选, 默认为 True) — 此进程是否为全局主进程(当在多台机器上以分布式方式进行训练时,对于只有一个进程,此项将为 True)。
  • is_hyper_param_search (bool, 可选, 默认为 False) — 我们是否正在使用 Trainer.hyperparameter_search 进行超参数搜索。这将影响数据在 TensorBoard 中记录的方式。
  • stateful_callbacks (List[StatefulTrainerCallback], 可选) — 附加到 Trainer 的回调,这些回调的状态应该被保存或恢复。相关的回调应实现 statefrom_state 函数。

一个包含 Trainer 内部状态的类,该状态将在检查点时与模型和优化器一起保存,并传递给 TrainerCallback

在本类中,一个步骤应理解为一个更新步骤。当使用梯度累积时,一个更新步骤可能需要多次前向和后向传递:如果您使用 gradient_accumulation_steps=n,则一个更新步骤需要处理 n 个批次。

compute_steps

< >

( args max_steps )

根据日志记录、评估和保存步骤是否为比例值,计算并存储其绝对值。

init_training_references

< >

( trainer max_steps num_train_epochs trial )

存储 self 中所需的初始训练引用

load_from_json

< >

( json_path: str )

json_path 的内容创建一个实例。

save_to_json

< >

( json_path: str )

将此实例的内容以 JSON 格式保存在 json_path 中。

TrainerControl

class transformers.TrainerControl

< >

( should_training_stop: bool = False should_epoch_stop: bool = False should_save: bool = False should_evaluate: bool = False should_log: bool = False )

参数

  • should_training_stop (bool, 可选, 默认为 False) — 是否应中断训练。

    如果为 True,则此变量不会设置回 False。训练将直接停止。

  • should_epoch_stop (bool, 可选, 默认为 False) — 是否应中断当前 epoch。

    如果为 True,则此变量将在下一个 epoch 开始时设置回 False

  • should_save (bool, 可选, 默认为 False) — 是否应在此步骤保存模型。

    如果为 True,则此变量将在下一步骤开始时设置回 False

  • should_evaluate (bool, 可选, 默认为 False) — 是否应在此步骤评估模型。

    如果为 True,则此变量将在下一步骤开始时设置回 False

  • should_log (bool, 可选, 默认为 False) — 是否应在此步骤报告日志。

    如果为 True,则此变量将在下一步骤开始时设置回 False

一个处理 Trainer 控制流的类。此类由 TrainerCallback 使用,以激活训练循环中的某些开关。

< > 在 GitHub 上更新