SetFit 文档

快速入门

Hugging Face's logo
加入 Hugging Face 社区

并获取增强的文档体验

开始使用

快速入门

本快速入门旨在为准备深入代码并查看如何训练和使用 🤗 SetFit 模型的示例的开发人员而设计。我们建议从本快速入门开始,然后继续阅读教程操作指南以获取更多资料。此外,概念指南有助于解释 SetFit 的确切工作原理。

首先安装 🤗 SetFit

pip install setfit

如果您有支持 CUDA 的显卡,建议安装支持 CUDA 的 torch,以更快地进行训练和推理

pip install torch --index-url https://download.pytorch.org/whl/cu118

SetFit

SetFit 是一个高效的框架,可以使用少量训练数据训练低延迟文本分类模型。在本快速入门中,您将学习如何训练 SetFit 模型、如何使用它执行推理以及如何将其保存到 Hugging Face Hub。

训练

在本节中,您将加载一个 Sentence Transformer 模型,并进一步对其进行微调,以将电影评论分类为正面或负面。要训练模型,我们将需要准备以下三项:1)模型,2)数据集,以及 3)训练参数

1. 使用我们选择的 Sentence Transformer 模型初始化 SetFit 模型。考虑使用 MTEB 排行榜来指导您决定选择哪个 Sentence Transformer 模型。我们将使用 BAAI/bge-small-en-v1.5,这是一个小型但性能良好的模型。

>>> from setfit import SetFitModel

>>> model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5")

2a. 接下来,加载 SetFit/sst2 数据集的 “train” 和 “test” 拆分。请注意,数据集具有 "text""label" 列:这正是 🤗 SetFit 期望的格式。如果您的数据集具有不同的列,则可以使用 Trainer 的 column_mapping 参数(在步骤 4 中)将列名映射到 "text""label"

>>> from datasets import load_dataset

>>> dataset = load_dataset("SetFit/sst2")
>>> dataset
DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 6920
    })
    test: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 1821
    })
    validation: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 872
    })
})

2b. 在真实世界的场景中,拥有约 7,000 个高质量标记的训练样本是非常不常见的,因此我们将大量缩小训练数据集,以便更好地了解 🤗 SetFit 在实际设置中的工作方式。具体来说,sample_dataset 函数将为每个类仅采样 8 个样本。测试集不受影响,以便更好地评估。

>>> from setfit import sample_dataset

>>> train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
>>> train_dataset
Dataset({
    features: ['text', 'label', 'label_text'],
    num_rows: 16
})
>>> test_dataset = dataset["test"]
>>> test_dataset
Dataset({
    features: ['text', 'label', 'label_text'],
    num_rows: 1821
})

2c. 我们可以将数据集中的标签应用于模型,以便预测输出可读的类。您也可以直接将标签提供给 SetFitModel.from_pretrained()

>>> model.labels = ["negative", "positive"]

3. 准备用于训练的 TrainingArguments。请注意,使用 🤗 SetFit 进行训练在幕后包含两个阶段:微调嵌入训练分类头。因此,某些训练参数可以是元组,其中两个值分别用于两个阶段。

num_epochsmax_steps 参数通常用于增加和减少总训练步数。请记住,使用 SetFit,更多数据而非更多训练可以获得更好的性能!如果您有大量数据,请不要害怕训练时间少于 1 个 epoch。

>>> from setfit import TrainingArguments

>>> args = TrainingArguments(
...     batch_size=32,
...     num_epochs=10,
... )

4. 初始化 Trainer 并执行训练。

>>> from setfit import Trainer

>>> trainer = Trainer(
...     model=model,
...     args=args,
...     train_dataset=train_dataset,
... )
>>> trainer.train()
***** Running training *****
  Num examples = 5
  Num epochs = 10
  Total optimization steps = 50
  Total train batch size = 32
{'embedding_loss': 0.2077, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.2}                                                                                                                
{'embedding_loss': 0.0097, 'learning_rate': 0.0, 'epoch': 10.0}                                                                                                                                 
{'train_runtime': 14.705, 'train_samples_per_second': 108.807, 'train_steps_per_second': 3.4, 'epoch': 10.0}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:08<00:00,  5.70it/s]

5. 使用提供的测试数据集执行评估。

>>> trainer.evaluate(test_dataset)
***** Running evaluation *****
{'accuracy': 0.8511806699615596}

随意尝试增加每个类的样本数量,以观察准确性的提高。作为一个挑战,您可以尝试调整每个类的样本数、学习率、epoch 数、最大步数和基础 Sentence Transformer 模型,以尝试使用极少量的数据将准确率提高到 90% 以上。

保存 🤗 SetFit 模型

训练后,您可以将 🤗 SetFit 模型保存到本地文件系统或 Hugging Face Hub。使用 SetFitModel.save_pretrained() 并提供 save_directory,将模型保存到本地目录

>>> model.save_pretrained("setfit-bge-small-v1.5-sst2-8-shot")

或者,使用 SetFitModel.push_to_hub() 并提供 repo_id,将模型推送到 Hugging Face Hub

>>> model.push_to_hub("tomaarsen/setfit-bge-small-v1.5-sst2-8-shot")

加载 🤗 SetFit 模型

可以使用 SetFitModel.from_pretrained() 加载 🤗 SetFit 模型,方法是提供 1) Hugging Face Hub 的 repo_id 或 2) 本地目录的路径

>>> model = SetFitModel.from_pretrained("tomaarsen/setfit-bge-small-v1.5-sst2-8-shot") # Load from the Hugging Face Hub

>>> model = SetFitModel.from_pretrained("setfit-bge-small-v1.5-sst2-8-shot") # Load from a local directory

推理

一旦 🤗 SetFit 模型经过训练,就可以使用它进行推理,以使用 SetFitModel.predict()SetFitModel.call() 对评论进行分类

>>> preds = model.predict([
...     "It's a charming and often affecting journey.",
...     "It's slow -- very, very slow.",
...     "A sometimes tedious film.",
... ])
>>> preds
['positive' 'negative' 'negative']

这些预测依赖于 model.labels。如果未设置,它将返回训练期间使用的格式的预测,例如 tensor([1, 0, 0])

下一步是什么?

您已完成 🤗 SetFit 快速入门!您可以训练、保存、加载和使用 🤗 SetFit 模型执行推理!

对于您的下一步,请查看我们的操作指南,了解如何执行更具体的操作,例如超参数搜索、知识蒸馏或零样本文本分类。如果您有兴趣了解有关 🤗 SetFit 工作原理的更多信息,请喝杯咖啡并阅读我们的概念指南

端到端

此代码段显示了端到端示例中的整个快速入门

from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset
from datasets import load_dataset

# Initializing a new SetFit model
model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5", labels=["negative", "positive"])

# Preparing the dataset
dataset = load_dataset("SetFit/sst2")
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
test_dataset = dataset["test"]

# Preparing the training arguments
args = TrainingArguments(
    batch_size=32,
    num_epochs=10,
)

# Preparing the trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
)
trainer.train()

# Evaluating
metrics = trainer.evaluate(test_dataset)
print(metrics)
# => {'accuracy': 0.8511806699615596}

# Saving the trained model
model.save_pretrained("setfit-bge-small-v1.5-sst2-8-shot")
# or
model.push_to_hub("tomaarsen/setfit-bge-small-v1.5-sst2-8-shot")

# Loading a trained model
model = SetFitModel.from_pretrained("tomaarsen/setfit-bge-small-v1.5-sst2-8-shot") # Load from the Hugging Face Hub
# or
model = SetFitModel.from_pretrained("setfit-bge-small-v1.5-sst2-8-shot") # Load from a local directory

# Performing inference
preds = model.predict([
    "It's a charming and often affecting journey.",
    "It's slow -- very, very slow.",
    "A sometimes tedious film.",
])
print(preds)
# => ["positive", "negative", "negative"]
< > 在 GitHub 上更新