SetFit 文档
零样本文本分类
加入 Hugging Face 社区
并获得增强的文档体验
开始使用
零样本文本分类
您的类名可能已经很好地描述了您想要分类的文本。使用 🤗 SetFit,您可以将这些类名与强大的预训练 Sentence Transformer 模型一起使用,从而无需任何训练样本即可获得一个强大的基线模型。
本指南将向您展示如何执行零样本文本分类。
测试数据集
我们将使用 dair-ai/emotion 数据集来测试零样本模型的性能。
from datasets import load_dataset
test_dataset = load_dataset("dair-ai/emotion", "split", split="test")
此数据集将类名存储在数据集 Features
中,因此我们将按如下方式提取类:
classes = test_dataset.features["label"].names
# => ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
否则,我们可以手动设置类列表。
合成数据集
然后,我们可以使用 get_templated_dataset() 根据这些类名合成生成一个虚拟数据集。
from setfit import get_templated_dataset
train_dataset = get_templated_dataset()
print(train_dataset)
# => Dataset({
# features: ['text', 'label'],
# num_rows: 48
# })
print(train_dataset[0])
# {'text': 'This sentence is sadness', 'label': 0}
训练
我们可以像往常一样使用此数据集来训练 SetFit 模型。
from setfit import SetFitModel, Trainer, TrainingArguments
model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5")
args = TrainingArguments(
batch_size=32,
num_epochs=1,
)
trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
)
trainer.train()
***** Running training *****
Num examples = 60
Num epochs = 1
Total optimization steps = 60
Total train batch size = 32
{'embedding_loss': 0.2628, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.02}
{'embedding_loss': 0.0222, 'learning_rate': 3.7037037037037037e-06, 'epoch': 0.83}
{'train_runtime': 15.4717, 'train_samples_per_second': 124.098, 'train_steps_per_second': 3.878, 'epoch': 1.0}
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:09<00:00, 6.35it/s]
训练后,我们可以评估模型。
metrics = trainer.evaluate()
print(metrics)
***** Running evaluation *****
{'accuracy': 0.591}
并运行预测。
preds = model.predict([
"i am just feeling cranky and blue",
"i feel incredibly lucky just to be able to talk to her",
"you're pissing me off right now",
"i definitely have thalassophobia, don't get me near water like that",
"i did not see that coming at all",
])
print([classes[idx] for idx in preds])
['sadness', 'joy', 'anger', 'fear', 'surprise']
这些预测看起来都正确!
基准
为了表明 SetFit 的零样本性能良好,我们将它与 transformers
中的零样本分类模型进行比较。
from transformers import pipeline
from datasets import load_dataset
import evaluate
# Prepare the testing dataset
test_dataset = load_dataset("dair-ai/emotion", "split", split="test")
classes = test_dataset.features["label"].names
# Set up the zero-shot classification pipeline from transformers
# Uses 'facebook/bart-large-mnli' by default
pipe = pipeline("zero-shot-classification", device=0)
zeroshot_preds = pipe(test_dataset["text"], batch_size=16, candidate_labels=classes)
preds = [classes.index(pred["labels"][0]) for pred in zeroshot_preds]
# Compute the accuracy
metric = evaluate.load("accuracy")
transformers_accuracy = metric.compute(predictions=preds, references=test_dataset["label"])
print(transformers_accuracy)
{'accuracy': 0.3765}
凭借 59.1% 的准确率,0-shot SetFit 显著优于 transformers
推荐的零样本模型。
预测延迟
除了获得更高的准确率,SetFit 的速度也快得多。让我们计算 SetFit 使用 BAAI/bge-small-en-v1.5
与 transformers
使用 facebook/bart-large-mnli
的延迟。两项测试均在 GPU 上执行。
import time
start_t = time.time()
pipe(test_dataset["text"], batch_size=32, candidate_labels=classes)
delta_t = time.time() - start_t
print(f"`transformers` with `facebook/bart-large-mnli` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence")
`transformers` with `facebook/bart-large-mnli` latency: 31.1765ms per sentence
import time
start_t = time.time()
model.predict(test_dataset["text"])
delta_t = time.time() - start_t
print(f"SetFit with `BAAI/bge-small-en-v1.5` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence")
SetFit with `BAAI/bge-small-en-v1.5` latency: 0.4600ms per sentence
因此,使用 BAAI/bge-small-en-v1.5
的 SetFit 比使用 facebook/bart-large-mnli
的 transformers
快 67 倍,同时更准确。