SetFit 文档

快速入门

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

快速入门

本快速入门面向准备深入代码并查看如何训练和使用 🤗 SetFit 模型的开发人员。我们建议从本快速入门开始,然后继续阅读教程操作指南以获取更多资料。此外,概念指南有助于准确解释 SetFit 的工作原理。

首先安装 🤗 SetFit

pip install setfit

如果您有支持 CUDA 的显卡,建议安装支持 CUDA 的 torch,以便更快地进行训练和推理

pip install torch --index-url https://download.pytorch.org/whl/cu118

SetFit

SetFit 是一个高效的框架,可以使用少量训练数据训练低延迟文本分类模型。在本快速入门中,您将学习如何训练 SetFit 模型、如何使用它进行推理以及如何将其保存到 Hugging Face Hub。

训练

在本节中,您将加载一个句子转换器模型,并进一步对其进行微调,以将电影评论分类为正面或负面。为了训练模型,我们需要准备以下三项:1) 一个模型,2) 一个数据集,以及 3) 训练参数

1. 使用我们选择的句子转换器模型初始化一个 SetFit 模型。考虑使用MTEB 排行榜来指导您选择哪个句子转换器模型。我们将使用BAAI/bge-small-en-v1.5,这是一个小巧但性能良好的模型。

>>> from setfit import SetFitModel

>>> model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5")

2a. 接下来,加载SetFit/sst2数据集的“训练”和“测试”拆分。请注意,数据集有 "text""label" 列:这正是 🤗 SetFit 期望的格式。如果您的数据集有不同的列,那么您可以在第 4 步中使用 Trainer 的 `column_mapping` 参数将列名映射到 "text""label"

>>> from datasets import load_dataset

>>> dataset = load_dataset("SetFit/sst2")
>>> dataset
DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 6920
    })
    test: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 1821
    })
    validation: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 872
    })
})

2b. 在实际场景中,拥有约 7,000 个高质量标记训练样本是非常罕见的,因此我们将大幅缩小训练数据集,以便更好地了解 🤗 SetFit 在实际设置中如何工作。具体来说,`sample_dataset` 函数将为每个类别仅采样 8 个样本。测试集不受影响,以便更好地进行评估。

>>> from setfit import sample_dataset

>>> train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
>>> train_dataset
Dataset({
    features: ['text', 'label', 'label_text'],
    num_rows: 16
})
>>> test_dataset = dataset["test"]
>>> test_dataset
Dataset({
    features: ['text', 'label', 'label_text'],
    num_rows: 1821
})

2c. 我们可以将数据集中的标签应用于模型,这样预测输出就可以是可读的类别。您也可以直接将标签提供给 `SetFitModel.from_pretrained()`。

>>> model.labels = ["negative", "positive"]

3. 准备训练参数以进行训练。请注意,使用 🤗 SetFit 进行训练在幕后包含两个阶段:微调嵌入训练分类头。因此,某些训练参数可以是元组,其中两个值分别用于这两个阶段。

`num_epochs` 和 `max_steps` 参数通常用于增加和减少总训练步数。请注意,使用 SetFit 时,更好的性能是通过更多数据,而不是更多训练来实现的!如果您有大量数据,即使训练不到 1 个 epoch 也无需担心。

>>> from setfit import TrainingArguments

>>> args = TrainingArguments(
...     batch_size=32,
...     num_epochs=10,
... )

4. 初始化训练器并执行训练。

>>> from setfit import Trainer

>>> trainer = Trainer(
...     model=model,
...     args=args,
...     train_dataset=train_dataset,
... )
>>> trainer.train()
***** Running training *****
  Num examples = 5
  Num epochs = 10
  Total optimization steps = 50
  Total train batch size = 32
{'embedding_loss': 0.2077, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.2}                                                                                                                
{'embedding_loss': 0.0097, 'learning_rate': 0.0, 'epoch': 10.0}                                                                                                                                 
{'train_runtime': 14.705, 'train_samples_per_second': 108.807, 'train_steps_per_second': 3.4, 'epoch': 10.0}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:08<00:00,  5.70it/s]

5. 使用提供的测试数据集进行评估。

>>> trainer.evaluate(test_dataset)
***** Running evaluation *****
{'accuracy': 0.8511806699615596}

随意尝试增加每个类别的样本数量,以观察准确性的提高。作为一项挑战,您可以尝试调整每个类别的样本数、学习率、 epoch 数、最大步数以及基础句子转换器模型,以尝试在少量数据下将准确性提高到 90% 以上。

保存 🤗 SetFit 模型

训练后,您可以将 🤗 SetFit 模型保存到本地文件系统或 Hugging Face Hub。通过提供 `save_directory`,使用 `SetFitModel.save_pretrained()` 将模型保存到本地目录

>>> model.save_pretrained("setfit-bge-small-v1.5-sst2-8-shot")

或者,通过提供 `repo_id`,使用 `SetFitModel.push_to_hub()` 将模型推送到 Hugging Face Hub

>>> model.push_to_hub("tomaarsen/setfit-bge-small-v1.5-sst2-8-shot")

加载 🤗 SetFit 模型

可以通过提供 1) 来自 Hugging Face Hub 的 `repo_id` 或 2) 本地目录的路径来使用 `SetFitModel.from_pretrained()` 加载 🤗 SetFit 模型

>>> model = SetFitModel.from_pretrained("tomaarsen/setfit-bge-small-v1.5-sst2-8-shot") # Load from the Hugging Face Hub

>>> model = SetFitModel.from_pretrained("setfit-bge-small-v1.5-sst2-8-shot") # Load from a local directory

推理

训练好 🤗 SetFit 模型后,可以使用SetFitModel.predict()SetFitModel.call()对评论进行分类以进行推理

>>> preds = model.predict([
...     "It's a charming and often affecting journey.",
...     "It's slow -- very, very slow.",
...     "A sometimes tedious film.",
... ])
>>> preds
['positive' 'negative' 'negative']

这些预测依赖于 `model.labels`。如果未设置,它将以训练期间使用的格式返回预测,例如 `tensor([1, 0, 0])`。

下一步是什么?

您已完成 🤗 SetFit 快速入门!您现在可以训练、保存、加载 🤗 SetFit 模型并进行推理!

接下来的步骤,请查阅我们的操作指南,了解如何进行超参数搜索、知识蒸馏或零样本文本分类等更具体的操作。如果您有兴趣深入了解 🤗 SetFit 的工作原理,请泡杯咖啡,阅读我们的概念指南

端到端

此代码片段展示了整个快速入门的端到端示例

from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset
from datasets import load_dataset

# Initializing a new SetFit model
model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5", labels=["negative", "positive"])

# Preparing the dataset
dataset = load_dataset("SetFit/sst2")
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
test_dataset = dataset["test"]

# Preparing the training arguments
args = TrainingArguments(
    batch_size=32,
    num_epochs=10,
)

# Preparing the trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
)
trainer.train()

# Evaluating
metrics = trainer.evaluate(test_dataset)
print(metrics)
# => {'accuracy': 0.8511806699615596}

# Saving the trained model
model.save_pretrained("setfit-bge-small-v1.5-sst2-8-shot")
# or
model.push_to_hub("tomaarsen/setfit-bge-small-v1.5-sst2-8-shot")

# Loading a trained model
model = SetFitModel.from_pretrained("tomaarsen/setfit-bge-small-v1.5-sst2-8-shot") # Load from the Hugging Face Hub
# or
model = SetFitModel.from_pretrained("setfit-bge-small-v1.5-sst2-8-shot") # Load from a local directory

# Performing inference
preds = model.predict([
    "It's a charming and often affecting journey.",
    "It's slow -- very, very slow.",
    "A sometimes tedious film.",
])
print(preds)
# => ["positive", "negative", "negative"]
< > 在 GitHub 上更新