SetFit 文档
回调
并获得增强的文档体验
开始使用
回调
SetFit 模型可以被回调影响,例如用于日志记录或提前停止。
本指南将向您展示它们是什么以及如何使用它们。
SetFit 中的回调
回调是自定义 SetFit Trainer 中训练循环行为的对象,可以检查训练循环状态(用于进度报告、日志记录、训练期间检查嵌入)并做出决策(例如提前停止)。
特别是,Trainer 使用 TrainerControl
,可以通过回调来影响它以停止训练、保存模型、评估或记录,以及 TrainerState
,它在训练期间跟踪一些训练循环指标,例如到目前为止的训练步数。
SetFit 依赖于 transformers
中实现的回调,如 transformers
文档 此处 所述。
默认回调
SetFit 使用 TrainingArguments.report_to
参数来指定应启用哪些内置回调。此参数默认为 "all"
,这意味着还将启用来自 transformers
的所有已安装的第三方回调。例如 TensorBoardCallback
或 WandbCallback
。
除此之外,PrinterCallback
或 ProgressCallback
始终启用以显示训练进度,并且 DefaultFlowCallback
也始终启用以正确更新 TrainerControl
。
使用回调
如前所述,您可以使用 TrainingArguments.report_to
来精确指定您想要启用的回调。例如
from setfit import TrainingArguments
args = TrainingArguments(
...,
report_to="wandb",
...,
)
# or
args = TrainingArguments(
...,
report_to=["wandb", "tensorboard"],
...,
)
您还可以使用 Trainer.add_callback()、Trainer.pop_callback() 和 Trainer.remove_callback() 来影响 trainer 回调,您可以通过 Trainer 初始化来指定回调,例如:
from setfit import Trainer
...
trainer = Trainer(
model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)
trainer.train()
自定义回调
SetFit 以与 transformers
相同的方式支持自定义回调:通过子类化 TrainerCallback
。此类实现了许多可以被重写的 on_...
方法。例如,以下脚本显示了一个自定义回调,该回调在训练期间保存训练和评估嵌入的 tSNE 图。
import os
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
class EmbeddingPlotCallback(TrainerCallback):
"""Simple embedding plotting callback that plots the tSNE of the training and evaluation datasets throughout training."""
def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
os.makedirs("logs", exist_ok=True)
def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: SetFitModel, **kwargs):
train_embeddings = model.encode(train_dataset["text"])
eval_embeddings = model.encode(eval_dataset["text"])
fig, (train_ax, eval_ax) = plt.subplots(ncols=2)
train_X = TSNE(n_components=2).fit_transform(train_embeddings)
train_ax.scatter(*train_X.T, c=train_dataset["label"], label=train_dataset["label"])
train_ax.set_title("Training embeddings")
eval_X = TSNE(n_components=2).fit_transform(eval_embeddings)
eval_ax.scatter(*eval_X.T, c=eval_dataset["label"], label=eval_dataset["label"])
eval_ax.set_title("Evaluation embeddings")
fig.suptitle(f"tSNE of training and evaluation embeddings at step {state.global_step} of {state.max_steps}.")
fig.savefig(f"logs/step_{state.global_step}.png")
与
trainer = Trainer( model=model, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, callbacks=[EmbeddingPlotCallback()] ) trainer.train()
EmbeddingPlotCallback
中的 on_evaluate
将在每次评估调用时触发。在本例中,它导致绘制了以下图形
步骤 20 | 步骤 40 |
---|---|
步骤 60 | 步骤 80 |