SetFit 文档

回调

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

回调

SetFit 模型可以通过回调函数(例如用于日志记录或提前停止)进行影响。

本指南将向您展示它们是什么以及如何使用它们。

SetFit 中的回调

回调是自定义 SetFit 训练器 中训练循环行为的对象,它们可以检查训练循环状态(用于进度报告、日志记录、训练期间检查嵌入)并做出决策(例如提前停止)。

特别是,训练器 使用一个 TrainerControl,它可以受到回调的影响来停止训练、保存模型、评估或记录,以及一个 TrainerState,它在训练期间跟踪一些训练循环指标,例如到目前为止的训练步数。

SetFit 依赖于 transformers 中实现的回调,如 transformers 文档 此处 所述。

默认回调

SetFit 使用 TrainingArguments.report_to 参数来指定应启用哪些内置回调。此参数默认为 "all",这意味着将启用所有已安装的 transformers 中的第三方回调。例如,TensorBoardCallbackWandbCallback

除此之外,PrinterCallbackProgressCallback 始终启用以显示训练进度,DefaultFlowCallback 也始终启用以正确更新 TrainerControl

使用回调

如前所述,您可以使用 TrainingArguments.report_to 来精确指定您希望启用哪些回调。例如:

from setfit import TrainingArguments

args = TrainingArguments(
    ...,
    report_to="wandb",
    ...,
)
# or 
args = TrainingArguments(
    ...,
    report_to=["wandb", "tensorboard"],
    ...,
)

您还可以使用 Trainer.add_callback()Trainer.pop_callback()Trainer.remove_callback() 来影响训练器回调,您还可以通过 训练器 初始化来指定回调,例如:

from setfit import Trainer

...

trainer = Trainer(
    model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)
trainer.train()

自定义回调

SetFit 支持自定义回调,其方式与 transformers 相同:通过子类化 TrainerCallback。此类别实现了许多可以重写的 on_... 方法。例如,以下脚本展示了一个自定义回调,它在训练期间保存训练和评估嵌入的 tSNE 绘图。

import os
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

class EmbeddingPlotCallback(TrainerCallback):
    """Simple embedding plotting callback that plots the tSNE of the training and evaluation datasets throughout training."""
    def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        os.makedirs("logs", exist_ok=True)

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: SetFitModel, **kwargs):
        train_embeddings = model.encode(train_dataset["text"])
        eval_embeddings = model.encode(eval_dataset["text"])

        fig, (train_ax, eval_ax) = plt.subplots(ncols=2)

        train_X = TSNE(n_components=2).fit_transform(train_embeddings)
        train_ax.scatter(*train_X.T, c=train_dataset["label"], label=train_dataset["label"])
        train_ax.set_title("Training embeddings")

        eval_X = TSNE(n_components=2).fit_transform(eval_embeddings)
        eval_ax.scatter(*eval_X.T, c=eval_dataset["label"], label=eval_dataset["label"])
        eval_ax.set_title("Evaluation embeddings")

        fig.suptitle(f"tSNE of training and evaluation embeddings at step {state.global_step} of {state.max_steps}.")
        fig.savefig(f"logs/step_{state.global_step}.png")

,其中

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[EmbeddingPlotCallback()]
)
trainer.train()

来自 EmbeddingPlotCallbackon_evaluate 将在每次评估调用时触发。在本例中,它产生了以下绘图:

第 20 步 第 40 步
step_20 step_40
第 60 步 第 80 步
step_60 step_80
< > 在 GitHub 上更新