SetFit 文档

回调

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

回调

SetFit 模型可以被回调影响,例如用于日志记录或提前停止。

本指南将向您展示它们是什么以及如何使用它们。

SetFit 中的回调

回调是自定义 SetFit Trainer 中训练循环行为的对象,可以检查训练循环状态(用于进度报告、日志记录、训练期间检查嵌入)并做出决策(例如提前停止)。

特别是,Trainer 使用 TrainerControl,可以通过回调来影响它以停止训练、保存模型、评估或记录,以及 TrainerState,它在训练期间跟踪一些训练循环指标,例如到目前为止的训练步数。

SetFit 依赖于 transformers 中实现的回调,如 transformers 文档 此处 所述。

默认回调

SetFit 使用 TrainingArguments.report_to 参数来指定应启用哪些内置回调。此参数默认为 "all",这意味着还将启用来自 transformers 的所有已安装的第三方回调。例如 TensorBoardCallbackWandbCallback

除此之外,PrinterCallbackProgressCallback 始终启用以显示训练进度,并且 DefaultFlowCallback 也始终启用以正确更新 TrainerControl

使用回调

如前所述,您可以使用 TrainingArguments.report_to 来精确指定您想要启用的回调。例如

from setfit import TrainingArguments

args = TrainingArguments(
    ...,
    report_to="wandb",
    ...,
)
# or 
args = TrainingArguments(
    ...,
    report_to=["wandb", "tensorboard"],
    ...,
)

您还可以使用 Trainer.add_callback()Trainer.pop_callback()Trainer.remove_callback() 来影响 trainer 回调,您可以通过 Trainer 初始化来指定回调,例如:

from setfit import Trainer

...

trainer = Trainer(
    model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)
trainer.train()

自定义回调

SetFit 以与 transformers 相同的方式支持自定义回调:通过子类化 TrainerCallback。此类实现了许多可以被重写的 on_... 方法。例如,以下脚本显示了一个自定义回调,该回调在训练期间保存训练和评估嵌入的 tSNE 图。

import os
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

class EmbeddingPlotCallback(TrainerCallback):
    """Simple embedding plotting callback that plots the tSNE of the training and evaluation datasets throughout training."""
    def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        os.makedirs("logs", exist_ok=True)

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: SetFitModel, **kwargs):
        train_embeddings = model.encode(train_dataset["text"])
        eval_embeddings = model.encode(eval_dataset["text"])

        fig, (train_ax, eval_ax) = plt.subplots(ncols=2)

        train_X = TSNE(n_components=2).fit_transform(train_embeddings)
        train_ax.scatter(*train_X.T, c=train_dataset["label"], label=train_dataset["label"])
        train_ax.set_title("Training embeddings")

        eval_X = TSNE(n_components=2).fit_transform(eval_embeddings)
        eval_ax.scatter(*eval_X.T, c=eval_dataset["label"], label=eval_dataset["label"])
        eval_ax.set_title("Evaluation embeddings")

        fig.suptitle(f"tSNE of training and evaluation embeddings at step {state.global_step} of {state.max_steps}.")
        fig.savefig(f"logs/step_{state.global_step}.png")

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[EmbeddingPlotCallback()]
)
trainer.train()

EmbeddingPlotCallback 中的 on_evaluate 将在每次评估调用时触发。在本例中,它导致绘制了以下图形

步骤 20 步骤 40
step_20 step_40
步骤 60 步骤 80
step_60 step_80
< > 在 GitHub 上更新