SetFit 文档

零样本文本分类

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

零样本文本分类

您的类别名称很可能已经是您想要分类的文本的良好描述符。借助 🤗 SetFit,您可以使用这些类别名称和强大的预训练 Sentence Transformer 模型,无需任何训练样本即可获得强大的基线模型。

本指南将向您展示如何执行零样本文本分类。

测试数据集

我们将使用 dair-ai/emotion 数据集来测试我们的零样本模型的性能。

from datasets import load_dataset

test_dataset = load_dataset("dair-ai/emotion", "split", split="test")

此数据集将类别名称存储在数据集 Features 中,因此我们将像这样提取类别

classes = test_dataset.features["label"].names
# => ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']

否则,我们可以手动设置类别列表。

合成数据集

然后,我们可以使用 get_templated_dataset() 根据这些类别名称合成生成一个虚拟数据集。

from setfit import get_templated_dataset

train_dataset = get_templated_dataset()
print(train_dataset)
# => Dataset({
#     features: ['text', 'label'],
#     num_rows: 48
# })
print(train_dataset[0])
# {'text': 'This sentence is sadness', 'label': 0}

训练

我们可以像平常一样使用此数据集来训练 SetFit 模型

from setfit import SetFitModel, Trainer, TrainingArguments

model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5")

args = TrainingArguments(
    batch_size=32,
    num_epochs=1,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)
trainer.train()
***** Running training *****
  Num examples = 60
  Num epochs = 1
  Total optimization steps = 60
  Total train batch size = 32
{'embedding_loss': 0.2628, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.02}                                                                                 
{'embedding_loss': 0.0222, 'learning_rate': 3.7037037037037037e-06, 'epoch': 0.83}                                                                                 
{'train_runtime': 15.4717, 'train_samples_per_second': 124.098, 'train_steps_per_second': 3.878, 'epoch': 1.0}                                                     
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:09<00:00,  6.35it/s]

训练完成后,我们可以评估模型

metrics = trainer.evaluate()
print(metrics)
***** Running evaluation *****
{'accuracy': 0.591}

并运行预测

preds = model.predict([
    "i am just feeling cranky and blue",
    "i feel incredibly lucky just to be able to talk to her",
    "you're pissing me off right now",
    "i definitely have thalassophobia, don't get me near water like that",
    "i did not see that coming at all",
])
print([classes[idx] for idx in preds])
['sadness', 'joy', 'anger', 'fear', 'surprise']

这些预测看起来都正确!

基线

为了展示 SetFit 的零样本性能良好,我们将它与 transformers 的零样本分类模型进行比较。

from transformers import pipeline
from datasets import load_dataset
import evaluate

# Prepare the testing dataset
test_dataset = load_dataset("dair-ai/emotion", "split", split="test")
classes = test_dataset.features["label"].names

# Set up the zero-shot classification pipeline from transformers
# Uses 'facebook/bart-large-mnli' by default
pipe = pipeline("zero-shot-classification", device=0)
zeroshot_preds = pipe(test_dataset["text"], batch_size=16, candidate_labels=classes)
preds = [classes.index(pred["labels"][0]) for pred in zeroshot_preds]

# Compute the accuracy
metric = evaluate.load("accuracy")
transformers_accuracy = metric.compute(predictions=preds, references=test_dataset["label"])
print(transformers_accuracy)
{'accuracy': 0.3765}

凭借其 59.1% 的准确率,0-shot SetFit 大大优于 transformers 推荐的零样本模型。

预测延迟

除了获得更高的准确率之外,SetFit 也快得多。 让我们计算一下使用 BAAI/bge-small-en-v1.5 的 SetFit 的延迟,以及使用 facebook/bart-large-mnli 的 transformers 的延迟。 两项测试均在 GPU 上执行。

import time

start_t = time.time()
pipe(test_dataset["text"], batch_size=32, candidate_labels=classes)
delta_t = time.time() - start_t
print(f"`transformers` with `facebook/bart-large-mnli` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence")
`transformers` with `facebook/bart-large-mnli` latency: 31.1765ms per sentence
import time

start_t = time.time()
model.predict(test_dataset["text"])
delta_t = time.time() - start_t
print(f"SetFit with `BAAI/bge-small-en-v1.5` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence")
SetFit with `BAAI/bge-small-en-v1.5` latency: 0.4600ms per sentence

因此,使用 BAAI/bge-small-en-v1.5 的 SetFit 比使用 facebook/bart-large-mnli 的 transformers 快 67 倍,同时更准确

zero_shot_transformers_vs_setfit

< > 更新 在 GitHub 上