SetFit 文档

零样本文本分类

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

零样本文本分类

虽然 SetFit 是为少样本学习而设计的,但该方法也可以应用于没有标记数据的场景。主要的技巧是创建类似于分类任务的*合成示例*,然后使用它们训练 SetFit 模型。

值得注意的是,这种简单的技术通常优于 🤗 Transformers 中的零样本管道,并且预测速度可以快 5 倍(或更多)!

在本教程中,我们将探讨如何

  • SetFit 可以应用于零样本分类
  • 添加合成示例还可以为少样本分类提供性能提升。

设置

如果您在 Colab 或其他云平台上运行此 Notebook,则需要安装 `setfit` 库。取消注释以下单元格并运行它

# %pip install setfit matplotlib

为了基准测试“零样本”方法的性能,我们将使用以下数据集和预训练模型

dataset_id = "emotion"
model_id = "sentence-transformers/paraphrase-mpnet-base-v2"

接下来,我们将从 Hugging Face Hub 下载参考数据集

from datasets import load_dataset

reference_dataset = load_dataset(dataset_id)
reference_dataset
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 16000
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
})

现在我们已经设置好,让我们创建一些合成数据进行训练!

创建合成数据集

我们需要做的第一件事是创建一个合成示例数据集。在 `setfit` 中,我们可以通过将 `get_templated_dataset()` 函数应用于虚拟数据集来做到这一点。此函数需要几个主要内容

  • 一个用于分类的候选标签列表。我们将在此处使用参考数据集中的标签,但这可以是与任务和当前数据集相关的任何内容。
  • 一个用于生成示例的模板。默认情况下,它是 `"This sentence is {}"`,其中 `{}` 将由一个候选标签填充
  • 一个样本大小 $N$,它将为每个类创建 $N$ 个合成示例。我们发现 $N=8$ 通常效果最好。

有了这些信息,我们首先从数据集中提取一些候选标签

# Extract ClassLabel feature from "label" column
label_features = reference_dataset["train"].features["label"]
# Label names to classify with
candidate_labels = label_features.names
candidate_labels
['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']

Hugging Face Hub 上的一些数据集在标签列中没有 `ClassLabel` 特性。在这种情况下,您应该首先计算 id2label 映射,然后手动计算候选标签,如下所示

def get_id2label(dataset):
    # The column with the label names
    label_names = dataset.unique("label_text")
    # The column with the label IDs
    label_ids = dataset.unique("label")
    id2label = dict(zip(label_ids, label_names))
    # Sort by label ID
    return {key: val for key, val in sorted(id2label.items(), key = lambda x: x[0])}

id2label = get_id2label(reference_dataset["train"])
candidate_labels = list(id2label.values())

现在我们有了标签,创建合成示例就变得很简单了

from datasets import Dataset
from setfit import get_templated_dataset

# A dummy dataset to fill with synthetic examples
dummy_dataset = Dataset.from_dict({})
train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8)
train_dataset
Dataset({
    features: ['text', 'label'],
    num_rows: 48
})

您可能会发现,通过将 `template` 参数从默认的 `"The sentence is {}"` 调整为 `"This sentence is {}"` 或 `"This example is {}"` 等变体,可以获得更好的性能。

由于我们的数据集有 6 个类,我们选择了 8 的样本大小,因此我们的合成数据集包含 $6\times 8=48$ 个示例。如果我们查看一些示例

train_dataset.shuffle()[:3]
{'text': ['This sentence is love',
  'This sentence is fear',
  'This sentence is joy'],
 'label': [2, 4, 1]}

我们可以看到每个输入都采用模板的形式,并具有与之关联的相应标签。

我们不要在这些示例上训练 SetFit 模型!

微调模型

要训练 SetFit 模型,首先要从 Hub 下载预训练检查点。我们可以通过使用 `SetFitModel.from_pretrained()` 方法来做到这一点

from setfit import SetFitModel

model = SetFitModel.from_pretrained(model_id)

在这里,我们从 Hub 下载了一个预训练的 Sentence Transformer,并添加了一个逻辑分类头来创建 SetFit 模型。如消息所示,我们需要在一些标记示例上训练这个模型。我们可以通过使用 Trainer 类来做到这一点

from setfit import Trainer

trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=reference_dataset["test"]
)

现在我们已经创建了一个训练器,我们可以训练它了!同时,让我们记录训练和评估模型所需的时间

%%time
trainer.train()
zeroshot_metrics = trainer.evaluate()
zeroshot_metrics
***** Running training *****
  Num examples = 1920
  Num epochs = 1
  Total optimization steps = 120
  Total train batch size = 16
***** Running evaluation *****
{'accuracy': 0.5345}
CPU times: user 12.9 s, sys: 2.37 s, total: 15.2 s
Wall time: 11 s

太好了,现在我们有了一个参考分数,让我们与 🤗 Transformers 中的零样本管道进行比较。

与 🤗 Transformers 中的零样本管道进行比较

🤗 Transformers 提供了一个零样本管道,将文本分类构建为自然语言推理任务。让我们加载管道并将其放在 GPU 上以实现快速推理

from transformers import pipeline

pipe = pipeline("zero-shot-classification", device=0)

现在我们有了模型,让我们生成一些预测。我们将使用与 SetFit 相同的候选标签,并增加批处理大小以加快速度

%%time
zeroshot_preds = pipe(reference_dataset["test"]["text"], batch_size=16, candidate_labels=candidate_labels)
CPU times: user 1min 10s, sys: 166 ms, total: 1min 11s
Wall time: 53.1 s

请注意,这比 SetFit 生成预测的时间长了近 5 倍!好的,那么它的表现如何呢?由于每个预测都是按分数排名的标签名称字典

zeroshot_preds[0]
{'sequence': 'im feeling rather rotten so im not very ambitious right now',
 'labels': ['sadness', 'anger', 'surprise', 'fear', 'joy', 'love'],
 'scores': [0.7367985844612122,
  0.10041674226522446,
  0.09770156443119049,
  0.05880110710859299,
  0.004266355652362108,
  0.0020156768150627613]}

我们可以使用 `label` 列中的 `str2int()` 函数将它们转换为整数。

preds = [label_features.str2int(pred["labels"][0]) for pred in zeroshot_preds]

**注意:** 如前所述,如果您使用的数据集的标签列没有 `ClassLabel` 特性,则需要手动计算标签映射,例如

id2label = get_id2label(reference_dataset["train"])
label2id = {v:k for k,v in id2label.items()}
preds = [label2id[pred["labels"][0]] for pred in zeroshot_preds]

最后一步是使用 🤗 Evaluate 计算准确率

import evaluate

metric = evaluate.load("accuracy")
transformers_metrics = metric.compute(predictions=preds, references=reference_dataset["test"]["label"])
transformers_metrics
{'accuracy': 0.3765}

与 SetFit 相比,这种方法的性能明显更差。让我们通过将合成示例与一些标记示例相结合来结束我们的分析。

用合成示例增强标记数据

如果您有一些标记示例,添加合成数据通常可以提高性能。为了模拟这一点,我们首先从参考数据集中采样 8 个标记示例

from setfit import sample_dataset

train_dataset = sample_dataset(reference_dataset["train"])
train_dataset
Dataset({
    features: ['text', 'label'],
    num_rows: 48
})

预热一下,我们将用这些真实标签训练一个 SetFit 模型

model = SetFitModel.from_pretrained(model_id)

trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=reference_dataset["test"]
)
trainer.train()
fewshot_metrics = trainer.evaluate()
fewshot_metrics
{'accuracy': 0.4705}

请注意,对于这个特定的数据集,使用真实标签的性能*比*使用合成示例训练的性能*更差*!在我们的实验中,我们发现差异很大程度上取决于具体的数据集。由于 SetFit 模型训练速度快,您总是可以尝试两种方法并选择最佳的一种。

无论如何,现在让我们向训练集中添加一些合成示例

augmented_dataset = get_templated_dataset(train_dataset, candidate_labels=candidate_labels, sample_size=8)
augmented_dataset
Dataset({
    features: ['text', 'label'],
    num_rows: 96
})

和以前一样,我们可以用增强数据集训练和评估 SetFit

model = SetFitModel.from_pretrained(model_id)

trainer = Trainer(
    model=model,
    train_dataset=augmented_dataset,
    eval_dataset=reference_dataset["test"]
)
trainer.train()
augmented_metrics = trainer.evaluate()
augmented_metrics
{'accuracy': 0.613}

太好了,这大大提升了我们的性能,比纯粹的合成示例提高了几个百分点。

让我们绘制最终结果进行比较

import pandas as pd

df = pd.DataFrame.from_dict({"Method":["Transformers (zero-shot)", "SetFit (zero-shot)", "SetFit (augmented)"], "Accuracy": [transformers_metrics["accuracy"], zeroshot_metrics["accuracy"], augmented_metrics["accuracy"]]})
df.plot(kind="barh", x="Method");                                       

setfit_zero_shot_results

< > 在 GitHub 上更新