Transformers 文档
回调
并获得增强的文档体验
开始使用
回调函数
回调函数是能够自定义 PyTorch Trainer 中训练循环行为的对象(此功能尚未在 TensorFlow 中实现),它可以检查训练循环状态(用于进度报告、在 TensorBoard 或其他 ML 平台上记录日志……)并做出决策(如提前停止)。
回调函数是“只读”的代码片段,除了它们返回的 TrainerControl 对象外,它们不能改变训练循环中的任何内容。对于需要改变训练循环的自定义,您应该子类化 Trainer 并覆盖您需要的方法(参阅 训练器 以获取示例)。
默认情况下,`TrainingArguments.report_to` 设置为 `"all"`,因此 Trainer 将使用以下回调函数。
- DefaultFlowCallback,它处理日志记录、保存和评估的默认行为。
- PrinterCallback 或 ProgressCallback 用于显示进度并打印日志(如果您通过 TrainingArguments 禁用 tqdm,则使用第一个,否则使用第二个)。
- 如果 TensorBoard 可访问(通过 PyTorch >= 1.4 或 tensorboardX),则使用 TensorBoardCallback。
- 如果安装了 wandb,则使用 WandbCallback。
- 如果安装了 comet_ml,则使用 CometCallback。
- 如果安装了 mlflow,则使用 MLflowCallback。
- 如果安装了 neptune,则使用 NeptuneCallback。
- 如果安装了 azureml-sdk,则使用 AzureMLCallback。
- 如果安装了 codecarbon,则使用 CodeCarbonCallback。
- 如果安装了 clearml,则使用 ClearMLCallback。
- 如果安装了 dagshub,则使用 DagsHubCallback。
- 如果安装了 flyte,则使用 FlyteCallback。
- 如果安装了 dvclive,则使用 DVCLiveCallback。
- 如果安装了 swanlab,则使用 SwanLabCallback。
如果已安装某个包,但您不想使用随附的集成,您可以将 `TrainingArguments.report_to` 更改为仅包含您要使用的集成的列表(例如,`["azure_ml", "wandb"]`)。
实现回调函数的主要类是 TrainerCallback。它获取用于实例化 Trainer 的 TrainingArguments,可以通过 TrainerState 访问该 Trainer 的内部状态,并且可以通过 TrainerControl 对训练循环执行一些操作。
可用回调函数
以下是库中可用的 TrainerCallback 列表
一个将日志发送到 Comet ML 的 TrainerCallback。
设置可选的 Comet 集成。
环境
- COMET_MODE (
str
, 可选,默认为get_or_create
):控制是创建并记录到新的 Comet 实验还是附加到现有实验。它接受以下值get_or_create
:根据是否设置了 `COMET_EXPERIMENT_KEY` 以及该键的实验是否存在自动决定。create
:始终创建一个新的 Comet 实验。get
:始终尝试附加到现有 Comet 实验。需要设置 `COMET_EXPERIMENT_KEY`。ONLINE
:**已弃用**,用于创建在线实验。请改用 `COMET_START_ONLINE=1`。OFFLINE
:**已弃用**,用于创建离线实验。请改用 `COMET_START_ONLINE=0`。DISABLED
:**已弃用**,用于禁用 Comet 日志记录。请改用 `--report_to` 标志来控制用于记录结果的集成。
- COMET_PROJECT_NAME (
str
, 可选):Comet 实验的项目名称。 - COMET_LOG_ASSETS (
str
, 可选,默认为TRUE
):是否将训练资产(tf 事件日志、检查点等)记录到 Comet。可以是TRUE
或FALSE
。
有关环境中可配置项的数量,请参阅此处。
一个 TrainerCallback,用于处理日志、评估和检查点的默认训练循环流。
一个只打印日志的 TrainerCallback。
一个 TrainerCallback,显示训练或评估的进度。您可以修改 `max_str_len` 来控制日志记录时字符串截断的长度。
class transformers.EarlyStoppingCallback
< 源 >( early_stopping_patience: int = 1 early_stopping_threshold: typing.Optional[float] = 0.0 )
一个处理提前停止的 TrainerCallback。
此回调函数依赖于 TrainingArguments 参数 *load_best_model_at_end* 功能,用于设置 TrainerState 中的 best_metric。请注意,如果 TrainingArguments 参数 *save_steps* 与 *eval_steps* 不同,则提前停止不会发生,直到下一个保存步骤。
class transformers.integrations.TensorBoardCallback
< 源 >( tb_writer = None )
一个将日志发送到 TensorBoard 的 TrainerCallback。
一个将指标、媒体、模型检查点记录到 Weight and Biases 的 TrainerCallback。
设置可选的 Weights & Biases (wandb) 集成。
可以子类化并重写此方法以根据需要自定义设置。欲了解更多信息,请参阅此处。您还可以覆盖以下环境变量
环境
WANDB_LOG_MODEL (
str
, 可选, 默认为"false"
):是否在训练期间记录模型和检查点。可以是"end"
,"checkpoint"
或"false"
。如果设置为"end"
,模型将在训练结束时上传。如果设置为"checkpoint"
,检查点将每args.save_steps
上传一次。如果设置为"false"
,模型将不会上传。与load_best_model_at_end()
一起使用以上传最佳模型。5.0 版本中已弃用
在 🤗 Transformers 的 5.0 版本中,将
WANDB_LOG_MODEL
设置为bool
将被弃用。WANDB_WATCH (
str
, 可选,默认为"false"
):可以是"gradients"
,"all"
,"parameters"
或"false"
。设置为"all"
以记录梯度和参数。WANDB_PROJECT (
str
, 可选,默认为"huggingface"
):将其设置为自定义字符串以将结果存储在不同的项目中。WANDB_DISABLED (
bool
, 可选,默认为False
):是否完全禁用 wandb。设置为 `WANDB_DISABLED=true` 以禁用。
一个将日志发送到 MLflow 的 TrainerCallback。可以通过设置环境变量 `DISABLE_MLFLOW_INTEGRATION = TRUE` 来禁用。
设置可选的 MLflow 集成。
环境
- HF_MLFLOW_LOG_ARTIFACTS (
str
, 可选):是否使用 MLflow `log_artifact()` 功能记录工件。这仅在记录到远程服务器(例如 s3 或 GCS)时才有意义。如果设置为 `True` 或 *1*,则在 TrainingArguments 的 `output_dir` 中每次保存时将每个保存的检查点复制到本地或远程工件存储。在没有远程存储的情况下使用它只会将文件复制到您的工件位置。 - MLFLOW_TRACKING_URI (
str
, 可选):是否将运行存储在特定路径或远程服务器。默认情况下未设置,这将完全跳过设置跟踪 URI。 - MLFLOW_EXPERIMENT_NAME (
str
, 可选, 默认为None
):是否使用 MLflow 实验名称来启动运行。默认为None
,这将指向 MLflow 中的 `Default` 实验。否则,它是要激活的实验的区分大小写名称。如果不存在具有此名称的实验,则会创建一个具有此名称的新实验。 - MLFLOW_TAGS (
str
, 可选):键/值对字典的字符串转储,将作为标签添加到 MLflow 运行中。示例:`os.environ['MLFLOW_TAGS']='{"release.candidate": "RC1", "release.version": "2.2.0"}'`。 - MLFLOW_NESTED_RUN (
str
, 可选):是否使用 MLflow 嵌套运行。如果设置为True
或 *1*,将在当前运行中创建一个嵌套运行。 - MLFLOW_RUN_ID (
str
, 可选):允许重新连接到现有运行,这在从检查点恢复训练时很有用。当设置了 `MLFLOW_RUN_ID` 环境变量时,`start_run` 尝试恢复具有指定运行 ID 的运行,其他参数将被忽略。 - MLFLOW_FLATTEN_PARAMS (
str
, 可选,默认为False
):是否在记录之前展平参数字典。 - MLFLOW_MAX_LOG_PARAMS (
int
, 可选):设置在运行中记录的最大参数数量。
一个将日志发送到 AzureML 的 TrainerCallback。
一个追踪训练过程中二氧化碳排放量的 TrainerCallback。
class transformers.integrations.NeptuneCallback
< 源 >( api_token: typing.Optional[str] = None project: typing.Optional[str] = None name: typing.Optional[str] = None base_namespace: str = 'finetuning' run = None log_parameters: bool = True log_checkpoints: typing.Optional[str] = None **neptune_run_kwargs )
参数
- api_token (
str
, 可选) — 注册时获取的 Neptune API 令牌。如果您已将令牌保存到 `NEPTUNE_API_TOKEN` 环境变量中(强烈推荐),则可以省略此参数。请参阅文档中的完整设置说明。 - project (
str
, 可选) — 现有 Neptune 项目的名称,格式为“workspace-name/project-name”。您可以在 Neptune 的项目设置 -> 属性中找到并复制该名称。如果为 None(默认),则使用 `NEPTUNE_PROJECT` 环境变量的值。 - name (
str
, 可选) — 运行的自定义名称。 - base_namespace (
str
, 可选, 默认为“finetuning”) — 在 Neptune 运行中,将包含回调函数记录的所有元数据的根命名空间。 - log_parameters (
bool
, 可选, 默认为True
) — 如果为 True,则记录 Trainer 提供的所有 Trainer 参数和模型参数。 - log_checkpoints (
str
, 可选) — 如果为“same”,则在 Trainer 保存检查点时上传检查点。如果为“last”,则仅上传最近保存的检查点。如果为“best”,则上传最佳检查点(Trainer 保存的检查点中)。如果为 `None`,则不上传检查点。 - run (
Run
, 可选) — 如果您想继续记录到现有运行,请传入一个 Neptune 运行对象。有关恢复运行的更多信息,请参阅文档。 - **neptune_run_kwargs (可选) — 在创建新运行时直接传递给
neptune.init_run()
函数的其他关键字参数。
将日志发送到 Neptune 的 TrainerCallback。
有关说明和示例,请参阅 Neptune 文档中的Transformers 集成指南。
一个将日志发送到 ClearML 的 TrainerCallback。
环境
- CLEARML_PROJECT (
str
, 可选,默认为HuggingFace Transformers
):ClearML 项目名称。 - CLEARML_TASK (
str
, 可选,默认为Trainer
):ClearML 任务名称。 - CLEARML_LOG_MODEL (
bool
, 可选,默认为False
):是否在训练期间将模型记录为工件。
一个将日志记录到 DagsHub 的 TrainerCallback。继承自 `MLflowCallback`
设置 DagsHub 的日志集成。
环境
- HF_DAGSHUB_LOG_ARTIFACTS (
str
, 可选):是否保存实验的数据和模型工件。默认为 `False`。
class transformers.integrations.FlyteCallback
< 源 >( save_log_history: bool = True sync_checkpoints: bool = True )
一个将日志发送到 Flyte 的 TrainerCallback。注意:此回调函数仅在 Flyte 任务中有效。
class transformers.integrations.DVCLiveCallback
< 源 >( live: typing.Optional[typing.Any] = None log_model: typing.Union[typing.Literal['all'], bool, NoneType] = None **kwargs )
参数
- live (
dvclive.Live
, 可选, 默认为None
) — 可选的 Live 实例。如果为 None,将使用 **kwargs 创建一个新实例。 - log_model (Union[Literal[“all”], bool], 可选, 默认为
None
) — 是否使用dvclive.Live.log_artifact()
记录由 Trainer 创建的检查点。如果设置为True
,最终检查点将在训练结束时记录。如果设置为"all"
,整个 TrainingArguments 的 `output_dir` 将在每个检查点记录。
一个将日志发送到 DVCLive 的 TrainerCallback。
在 `setup` 中使用以下环境变量配置集成。要在此环境变量之外自定义此回调函数,请参阅此处。
设置可选的 DVCLive 集成。要在此环境变量之外自定义此回调函数,请参阅此处。
环境
- HF_DVCLIVE_LOG_MODEL (
str
, 可选):是否使用 `dvclive.Live.log_artifact()` 记录由 Trainer 创建的检查点。如果设置为 `True` 或 *1*,最终检查点将在训练结束时记录。如果设置为 `all`,整个 TrainingArguments 的 `output_dir` 将在每个检查点记录。
一个将指标、媒体、模型检查点记录到 SwanLab 的 TrainerCallback。
设置可选的 SwanLab (swanlab) 集成。
如有需要,可以子类化并覆盖此方法以自定义设置。更多信息请参阅此处。
您还可以覆盖以下环境变量。更多关于环境变量的信息请参阅此处
环境
SWANLAB_API_KEY (
str
, 可选, 默认为None
): 云API密钥。登录时,首先检查此环境变量。如果不存在,系统会检查用户是否已登录。如果未登录,则启动登录过程。- 如果将字符串传递给登录接口,则忽略此环境变量。
- 如果用户已登录,此环境变量优先于本地存储的登录信息。
SWANLAB_PROJECT (
str
, 可选, 默认为None
): 将此设置为自定义字符串,以便将结果存储在不同的项目中。如果未指定,则使用当前运行目录的名称。SWANLAB_LOG_DIR (
str
, 可选, 默认为swanlog
): 此环境变量指定在本地模式下运行时的日志文件存储路径。默认情况下,日志保存在工作目录下名为 swanlog 的文件夹中。SWANLAB_MODE (
Literal["local", "cloud", "disabled"]
, 可选, 默认为cloud
): SwanLab 的解析模式,涉及操作员注册的回调。目前有三种模式:本地、云和禁用。注意:区分大小写。更多信息请参阅此处SWANLAB_LOG_MODEL (
str
, 可选, 默认为None
): SwanLab 目前不支持保存模式功能。此功能将在未来版本中提供。SWANLAB_WEB_HOST (
str
, 可选, 默认为None
): 私有版本 SwanLab 云环境的 Web 地址(免费)SWANLAB_API_HOST (
str
, 可选, 默认为None
): 私有版本 SwanLab 云环境的 API 地址(免费)
TrainerCallback
class transformers.TrainerCallback
< source >( )
参数
- args (TrainingArguments) — 用于实例化 Trainer 的训练参数。
- state (TrainerState) — Trainer 的当前状态。
- control (TrainerControl) — 返回给 Trainer 的对象,可用于做出某些决策。
- model (PreTrainedModel 或
torch.nn.Module
) — 正在训练的模型。 - tokenizer (PreTrainedTokenizer) — 用于数据编码的分词器。此参数已被弃用,建议使用
processing_class
。 - processing_class ([
PreTrainedTokenizer
或BaseImageProcessor
或ProcessorMixin
或FeatureExtractionMixin
]) — 用于数据编码的处理类。可以是分词器、处理器、图像处理器或特征提取器。 - optimizer (
torch.optim.Optimizer
) — 用于训练步骤的优化器。 - lr_scheduler (
torch.optim.lr_scheduler.LambdaLR
) — 用于设置学习率的调度器。 - train_dataloader (
torch.utils.data.DataLoader
, 可选) — 当前用于训练的数据加载器。 - eval_dataloader (
torch.utils.data.DataLoader
, 可选) — 当前用于评估的数据加载器。 - metrics (
dict[str, float]
) — 上一评估阶段计算的指标。这些仅在
on_evaluate
事件中可访问。 - logs (
dict[str, float]
) — 要记录的值。这些仅在
on_log
事件中可访问。
一个类,用于在某些事件中检查训练循环状态并做出决策的对象。在每个这些事件中,以下参数可用:
control
对象是唯一可以被回调更改的对象,在这种情况下,更改它的事件应返回修改后的版本。
参数 args
、state
和 control
是所有事件的位置参数,所有其他参数都分组在 kwargs
中。您可以使用它们的签名解包您需要的参数。例如,请参见简单 PrinterCallback 的代码。
示例
class PrinterCallback(TrainerCallback):
def on_log(self, args, state, control, logs=None, **kwargs):
_ = logs.pop("total_flos", None)
if state.is_local_process_zero:
print(logs)
on_epoch_begin
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
一个 epoch 开始时调用的事件。
on_epoch_end
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
一个 epoch 结束时调用的事件。
on_evaluate
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
评估阶段结束后调用的事件。
on_init_end
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
Trainer 初始化结束时调用的事件。
记录最新日志后调用的事件。
on_optimizer_step
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
优化器步骤之后但在梯度归零之前调用的事件。用于监控梯度很有用。
on_pre_optimizer_step
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
优化器步骤之前但在梯度裁剪之后调用的事件。用于监控梯度很有用。
on_predict
< source >( args: TrainingArguments state: TrainerState control: TrainerControl metrics **kwargs )
成功预测后调用的事件。
on_prediction_step
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
预测步骤后调用的事件。
检查点保存后调用的事件。
on_step_begin
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
训练步骤开始时调用的事件。如果使用梯度累积,一个训练步骤可能需要多个输入。
on_step_end
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
训练步骤结束时调用的事件。如果使用梯度累积,一个训练步骤可能需要多个输入。
on_substep_end
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
梯度累积期间一个子步骤结束时调用的事件。
on_train_begin
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
训练开始时调用的事件。
on_train_end
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
训练结束时调用的事件。
以下是如何向 PyTorch Trainer 注册自定义回调的示例
class MyCallback(TrainerCallback):
"A callback that prints a message at the beginning of training"
def on_train_begin(self, args, state, control, **kwargs):
print("Starting training")
trainer = Trainer(
model,
args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
callbacks=[MyCallback], # We can either pass the callback class this way or an instance of it (MyCallback())
)
另一种注册回调的方法是调用 trainer.add_callback()
,如下所示
trainer = Trainer(...)
trainer.add_callback(MyCallback)
# Alternatively, we can pass an instance of the callback class
trainer.add_callback(MyCallback())
TrainerState
class transformers.TrainerState
< source >( epoch: typing.Optional[float] = None global_step: int = 0 max_steps: int = 0 logging_steps: int = 500 eval_steps: int = 500 save_steps: int = 500 train_batch_size: typing.Optional[int] = None num_train_epochs: int = 0 num_input_tokens_seen: int = 0 total_flos: float = 0 log_history: list = None best_metric: typing.Optional[float] = None best_global_step: typing.Optional[int] = None best_model_checkpoint: typing.Optional[str] = None is_local_process_zero: bool = True is_world_process_zero: bool = True is_hyper_param_search: bool = False trial_name: typing.Optional[str] = None trial_params: dict = None stateful_callbacks: list = None )
参数
- epoch (
float
, 可选) — 仅在训练期间设置,表示训练所在的 epoch(小数部分表示当前 epoch 完成的百分比)。 - global_step (
int
, 可选, 默认为 0) — 训练期间,表示已完成的更新步骤数。 - max_steps (
int
, 可选, 默认为 0) — 当前训练期间要执行的更新步骤数。 - logging_steps (
int
, 可选, 默认为 500) — 每 X 更新步骤记录一次日志。 - eval_steps (
int
, 可选) — 每 X 步骤运行一次评估。 - save_steps (
int
, 可选, 默认为 500) — 每 X 更新步骤保存一次检查点。 - train_batch_size (
int
, 可选) — 训练数据加载器的批大小。仅在使用了auto_find_batch_size
时需要。 - num_input_tokens_seen (
int
, 可选, 默认为 0) — 当跟踪输入 token 时,训练期间已看到的 token 数量(输入 token 的数量,而非预测 token 的数量)。 - total_flos (
float
, 可选, 默认为 0) — 模型自训练开始以来完成的浮点运算总数(存储为浮点数以避免溢出)。 - log_history (
list[dict[str, float]]
, 可选) — 自训练开始以来的日志列表。 - best_metric (
float
, 可选) — 当跟踪最佳模型时,迄今为止遇到的最佳指标值。 - best_global_step (
int
, 可选) — 当跟踪最佳模型时,遇到最佳指标时的步骤。用于设置best_model_checkpoint
。 - best_model_checkpoint (
str
, 可选) — 当跟踪最佳模型时,迄今为止遇到的最佳模型的检查点名称。 - is_local_process_zero (
bool
, 可选, 默认为True
) — 此进程是否为本地(例如,如果以分布式方式在多台机器上训练)主进程。 - is_world_process_zero (
bool
, 可选, 默认为True
) — 此进程是否为全局主进程(当以分布式方式在多台机器上训练时,此参数仅对一个进程为True
)。 - is_hyper_param_search (
bool
, 可选, 默认为False
) — 是否正在使用 Trainer.hyperparameter_search 进行超参数搜索。这将影响数据在 TensorBoard 中的记录方式。 - stateful_callbacks (
list[StatefulTrainerCallback]
, 可选) — 附加到Trainer
的回调,其状态应被保存或恢复。相关回调应实现state
和from_state
函数。
一个包含 Trainer 内部状态的类,该状态将在检查点时与模型和优化器一起保存,并传递给 TrainerCallback。
在本类中,一步被理解为一次更新步。当使用梯度累积时,一次更新步可能需要多次正向和反向传播:如果您使用 gradient_accumulation_steps=n
,那么一次更新步需要经过 *n* 个批次。
根据是否为比例来计算并存储用于日志记录、评估和保存步骤的绝对值。
存储 self
中所需的初始训练引用
从 json_path
的内容创建实例。
以 JSON 格式将此实例的内容保存在 json_path
中。
TrainerControl
类 transformers.TrainerControl
< 源 >( should_training_stop: bool = False should_epoch_stop: bool = False should_save: bool = False should_evaluate: bool = False should_log: bool = False )
参数
- should_training_stop (
bool
, 可选, 默认为False
) — 训练是否应该中断。如果为
True
,此变量将不会被重置为False
。训练将直接停止。 - should_epoch_stop (
bool
, 可选, 默认为False
) — 当前 epoch 是否应该中断。如果为
True
,此变量将在下一个 epoch 开始时被重置为False
。 - should_save (
bool
, 可选, 默认为False
) — 模型是否应该在此步保存。如果为
True
,此变量将在下一步开始时被重置为False
。 - should_evaluate (
bool
, 可选, 默认为False
) — 模型是否应该在此步评估。如果为
True
,此变量将在下一步开始时被重置为False
。 - should_log (
bool
, 可选, 默认为False
) — 日志是否应该在此步报告。如果为
True
,此变量将在下一步开始时被重置为False
。
一个处理 Trainer 控制流的类。此类由 TrainerCallback 使用,以激活训练循环中的某些开关。