SetFit 文档
快速入门
并获取增强的文档体验
开始使用
快速入门
本快速入门旨在为准备深入代码并查看如何训练和使用 🤗 SetFit 模型的示例的开发人员而设计。我们建议从本快速入门开始,然后继续阅读教程或操作指南以获取更多资料。此外,概念指南有助于解释 SetFit 的确切工作原理。
首先安装 🤗 SetFit
pip install setfit
如果您有支持 CUDA 的显卡,建议安装支持 CUDA 的 torch
,以更快地进行训练和推理
pip install torch --index-url https://download.pytorch.org/whl/cu118
SetFit
SetFit 是一个高效的框架,可以使用少量训练数据训练低延迟文本分类模型。在本快速入门中,您将学习如何训练 SetFit 模型、如何使用它执行推理以及如何将其保存到 Hugging Face Hub。
训练
在本节中,您将加载一个 Sentence Transformer 模型,并进一步对其进行微调,以将电影评论分类为正面或负面。要训练模型,我们将需要准备以下三项:1)模型,2)数据集,以及 3)训练参数。
1. 使用我们选择的 Sentence Transformer 模型初始化 SetFit 模型。考虑使用 MTEB 排行榜来指导您决定选择哪个 Sentence Transformer 模型。我们将使用 BAAI/bge-small-en-v1.5,这是一个小型但性能良好的模型。
>>> from setfit import SetFitModel
>>> model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5")
2a. 接下来,加载 SetFit/sst2 数据集的 “train” 和 “test” 拆分。请注意,数据集具有 "text"
和 "label"
列:这正是 🤗 SetFit 期望的格式。如果您的数据集具有不同的列,则可以使用 Trainer 的 column_mapping 参数(在步骤 4 中)将列名映射到 "text"
和 "label"
。
>>> from datasets import load_dataset
>>> dataset = load_dataset("SetFit/sst2")
>>> dataset
DatasetDict({
train: Dataset({
features: ['text', 'label', 'label_text'],
num_rows: 6920
})
test: Dataset({
features: ['text', 'label', 'label_text'],
num_rows: 1821
})
validation: Dataset({
features: ['text', 'label', 'label_text'],
num_rows: 872
})
})
2b. 在真实世界的场景中,拥有约 7,000 个高质量标记的训练样本是非常不常见的,因此我们将大量缩小训练数据集,以便更好地了解 🤗 SetFit 在实际设置中的工作方式。具体来说,sample_dataset
函数将为每个类仅采样 8 个样本。测试集不受影响,以便更好地评估。
>>> from setfit import sample_dataset
>>> train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
>>> train_dataset
Dataset({
features: ['text', 'label', 'label_text'],
num_rows: 16
})
>>> test_dataset = dataset["test"]
>>> test_dataset
Dataset({
features: ['text', 'label', 'label_text'],
num_rows: 1821
})
2c. 我们可以将数据集中的标签应用于模型,以便预测输出可读的类。您也可以直接将标签提供给 SetFitModel.from_pretrained()
。
>>> model.labels = ["negative", "positive"]
3. 准备用于训练的 TrainingArguments。请注意,使用 🤗 SetFit 进行训练在幕后包含两个阶段:微调嵌入和训练分类头。因此,某些训练参数可以是元组,其中两个值分别用于两个阶段。
num_epochs
和 max_steps
参数通常用于增加和减少总训练步数。请记住,使用 SetFit,更多数据而非更多训练可以获得更好的性能!如果您有大量数据,请不要害怕训练时间少于 1 个 epoch。
>>> from setfit import TrainingArguments
>>> args = TrainingArguments(
... batch_size=32,
... num_epochs=10,
... )
4. 初始化 Trainer 并执行训练。
>>> from setfit import Trainer
>>> trainer = Trainer(
... model=model,
... args=args,
... train_dataset=train_dataset,
... )
>>> trainer.train()
***** Running training *****
Num examples = 5
Num epochs = 10
Total optimization steps = 50
Total train batch size = 32
{'embedding_loss': 0.2077, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.2}
{'embedding_loss': 0.0097, 'learning_rate': 0.0, 'epoch': 10.0}
{'train_runtime': 14.705, 'train_samples_per_second': 108.807, 'train_steps_per_second': 3.4, 'epoch': 10.0}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:08<00:00, 5.70it/s]
5. 使用提供的测试数据集执行评估。
>>> trainer.evaluate(test_dataset)
***** Running evaluation *****
{'accuracy': 0.8511806699615596}
随意尝试增加每个类的样本数量,以观察准确性的提高。作为一个挑战,您可以尝试调整每个类的样本数、学习率、epoch 数、最大步数和基础 Sentence Transformer 模型,以尝试使用极少量的数据将准确率提高到 90% 以上。
保存 🤗 SetFit 模型
训练后,您可以将 🤗 SetFit 模型保存到本地文件系统或 Hugging Face Hub。使用 SetFitModel.save_pretrained()
并提供 save_directory
,将模型保存到本地目录
>>> model.save_pretrained("setfit-bge-small-v1.5-sst2-8-shot")
或者,使用 SetFitModel.push_to_hub()
并提供 repo_id
,将模型推送到 Hugging Face Hub
>>> model.push_to_hub("tomaarsen/setfit-bge-small-v1.5-sst2-8-shot")
加载 🤗 SetFit 模型
可以使用 SetFitModel.from_pretrained()
加载 🤗 SetFit 模型,方法是提供 1) Hugging Face Hub 的 repo_id
或 2) 本地目录的路径
>>> model = SetFitModel.from_pretrained("tomaarsen/setfit-bge-small-v1.5-sst2-8-shot") # Load from the Hugging Face Hub
>>> model = SetFitModel.from_pretrained("setfit-bge-small-v1.5-sst2-8-shot") # Load from a local directory
推理
一旦 🤗 SetFit 模型经过训练,就可以使用它进行推理,以使用 SetFitModel.predict() 或 SetFitModel.call() 对评论进行分类
>>> preds = model.predict([
... "It's a charming and often affecting journey.",
... "It's slow -- very, very slow.",
... "A sometimes tedious film.",
... ])
>>> preds
['positive' 'negative' 'negative']
这些预测依赖于 model.labels
。如果未设置,它将返回训练期间使用的格式的预测,例如 tensor([1, 0, 0])
。
下一步是什么?
您已完成 🤗 SetFit 快速入门!您可以训练、保存、加载和使用 🤗 SetFit 模型执行推理!
对于您的下一步,请查看我们的操作指南,了解如何执行更具体的操作,例如超参数搜索、知识蒸馏或零样本文本分类。如果您有兴趣了解有关 🤗 SetFit 工作原理的更多信息,请喝杯咖啡并阅读我们的概念指南!
端到端
此代码段显示了端到端示例中的整个快速入门
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset
from datasets import load_dataset
# Initializing a new SetFit model
model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5", labels=["negative", "positive"])
# Preparing the dataset
dataset = load_dataset("SetFit/sst2")
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
test_dataset = dataset["test"]
# Preparing the training arguments
args = TrainingArguments(
batch_size=32,
num_epochs=10,
)
# Preparing the trainer
trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
)
trainer.train()
# Evaluating
metrics = trainer.evaluate(test_dataset)
print(metrics)
# => {'accuracy': 0.8511806699615596}
# Saving the trained model
model.save_pretrained("setfit-bge-small-v1.5-sst2-8-shot")
# or
model.push_to_hub("tomaarsen/setfit-bge-small-v1.5-sst2-8-shot")
# Loading a trained model
model = SetFitModel.from_pretrained("tomaarsen/setfit-bge-small-v1.5-sst2-8-shot") # Load from the Hugging Face Hub
# or
model = SetFitModel.from_pretrained("setfit-bge-small-v1.5-sst2-8-shot") # Load from a local directory
# Performing inference
preds = model.predict([
"It's a charming and often affecting journey.",
"It's slow -- very, very slow.",
"A sometimes tedious film.",
])
print(preds)
# => ["positive", "negative", "negative"]