SetFit 文档

零样本文本分类

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

零样本文本分类

尽管 SetFit 最初是为少样本学习设计的,但该方法也适用于没有标记数据的情况。 主要技巧是创建合成示例,使其类似于分类任务,然后在这些示例上训练 SetFit 模型。

值得注意的是,这种简单技术通常优于 🤗 Transformers 中的零样本流水线,并且可以快 5 倍(或更多)地生成预测!

在本教程中,我们将探讨如何

  • 将 SetFit 应用于零样本分类
  • 添加合成示例还可以提高少样本分类的性能。

设置

如果您在 Colab 或其他云平台上运行此 Notebook,则需要安装 setfit 库。 取消注释以下单元格并运行它

# %pip install setfit matplotlib

为了基准测试“零样本”方法的性能,我们将使用以下数据集和预训练模型

dataset_id = "emotion"
model_id = "sentence-transformers/paraphrase-mpnet-base-v2"

接下来,我们将从 Hugging Face Hub 下载参考数据集

from datasets import load_dataset

reference_dataset = load_dataset(dataset_id)
reference_dataset
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 16000
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
})

现在我们已经设置好了,让我们创建一些合成数据来训练!

创建合成数据集

我们需要做的第一件事是创建一个合成示例数据集。 在 setfit 中,我们可以通过将 get_templated_dataset() 函数应用于虚拟数据集来完成此操作。 此函数需要几个主要内容

  • 用于分类的候选标签列表。 我们将在此处使用来自参考数据集的标签,但这可以是与手头的任务和数据集相关的任何内容。
  • 用于生成示例的模板。 默认情况下,它是 "This sentence is {}",其中 {} 将由候选标签之一填充
  • 样本大小 $N$,它将为每个类创建 $N$ 个合成示例。 我们发现 $N=8$ 通常效果最佳。

掌握了这些信息,让我们首先从数据集中提取一些候选标签

# Extract ClassLabel feature from "label" column
label_features = reference_dataset["train"].features["label"]
# Label names to classify with
candidate_labels = label_features.names
candidate_labels
['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']

Hugging Face Hub 上的一些数据集在标签列中没有 ClassLabel 特征。 在这些情况下,您应该通过首先计算 id2label 映射来手动计算候选标签,如下所示

def get_id2label(dataset):
    # The column with the label names
    label_names = dataset.unique("label_text")
    # The column with the label IDs
    label_ids = dataset.unique("label")
    id2label = dict(zip(label_ids, label_names))
    # Sort by label ID
    return {key: val for key, val in sorted(id2label.items(), key = lambda x: x[0])}

id2label = get_id2label(reference_dataset["train"])
candidate_labels = list(id2label.values())

现在我们有了标签,创建合成示例就变得很简单了

from datasets import Dataset
from setfit import get_templated_dataset

# A dummy dataset to fill with synthetic examples
dummy_dataset = Dataset.from_dict({})
train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8)
train_dataset
Dataset({
    features: ['text', 'label'],
    num_rows: 48
})

您可能会发现,通过将 template 参数从默认的 "The sentence is {}" 调整为诸如 "This sentence is {}""This example is {}" 之类的变体,可以获得更好的性能。

由于我们的数据集有 6 个类,并且我们选择了 8 的样本大小,因此我们的合成数据集包含 $6\times 8=48$ 个示例。 如果我们看一下其中的一些示例

train_dataset.shuffle()[:3]
{'text': ['This sentence is love',
  'This sentence is fear',
  'This sentence is joy'],
 'label': [2, 4, 1]}

我们可以看到,每个输入都采用模板的形式,并具有与之关联的相应标签。

让我们在这些示例上训练一个 SetFit 模型!

微调模型

要训练 SetFit 模型,首先要做的是从 Hub 下载预训练的检查点。 我们可以通过使用 SetFitModel.from_pretrained() 方法来做到这一点

from setfit import SetFitModel

model = SetFitModel.from_pretrained(model_id)

在这里,我们从 Hub 下载了一个预训练的 Sentence Transformer,并添加了一个逻辑分类头来创建 SetFit 模型。 正如消息中所示,我们需要在一些标记示例上训练此模型。 我们可以通过使用 Trainer 类来做到这一点,如下所示

from setfit import Trainer

trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=reference_dataset["test"]
)

现在我们已经创建了一个训练器,我们可以训练它了! 在我们进行训练的同时,让我们计时训练和评估模型需要多长时间

%%time
trainer.train()
zeroshot_metrics = trainer.evaluate()
zeroshot_metrics
***** Running training *****
  Num examples = 1920
  Num epochs = 1
  Total optimization steps = 120
  Total train batch size = 16
***** Running evaluation *****
{'accuracy': 0.5345}
CPU times: user 12.9 s, sys: 2.37 s, total: 15.2 s
Wall time: 11 s

太棒了,现在我们有了参考分数,让我们与 🤗 Transformers 的零样本流水线进行比较。

与 🤗 Transformers 的零样本流水线进行比较

🤗 Transformers 提供了一个零样本流水线,该流水线将文本分类构建为自然语言推理任务。 让我们加载流水线并将其放置在 GPU 上以进行快速推理

from transformers import pipeline

pipe = pipeline("zero-shot-classification", device=0)

现在我们有了模型,让我们生成一些预测。 我们将使用与 SetFit 相同的候选标签,并增加批大小以加快速度

%%time
zeroshot_preds = pipe(reference_dataset["test"]["text"], batch_size=16, candidate_labels=candidate_labels)
CPU times: user 1min 10s, sys: 166 ms, total: 1min 11s
Wall time: 53.1 s

请注意,这比 SetFit 生成预测花费的时间几乎长 5 倍! 好的,那么它的表现如何呢? 由于每个预测都是按分数排序的标签名称字典

zeroshot_preds[0]
{'sequence': 'im feeling rather rotten so im not very ambitious right now',
 'labels': ['sadness', 'anger', 'surprise', 'fear', 'joy', 'love'],
 'scores': [0.7367985844612122,
  0.10041674226522446,
  0.09770156443119049,
  0.05880110710859299,
  0.004266355652362108,
  0.0020156768150627613]}

我们可以使用 label 列中的 str2int() 函数将它们转换为整数。

preds = [label_features.str2int(pred["labels"][0]) for pred in zeroshot_preds]

注意: 如前所述,如果您使用的数据集在标签列中没有 ClassLabel 特征,则需要使用类似以下内容手动计算标签映射

id2label = get_id2label(reference_dataset["train"])
label2id = {v:k for k,v in id2label.items()}
preds = [label2id[pred["labels"][0]] for pred in zeroshot_preds]

最后一步是使用 🤗 Evaluate 计算准确率

import evaluate

metric = evaluate.load("accuracy")
transformers_metrics = metric.compute(predictions=preds, references=reference_dataset["test"]["label"])
transformers_metrics
{'accuracy': 0.3765}

与 SetFit 相比,这种方法的性能明显更差。 让我们通过将合成示例与一些标记示例相结合来结束我们的分析。

使用合成示例增强标记数据

如果您有一些标记示例,则添加合成数据通常可以提高性能。 为了模拟这一点,让我们首先从我们的参考数据集中采样 8 个标记示例

from setfit import sample_dataset

train_dataset = sample_dataset(reference_dataset["train"])
train_dataset
Dataset({
    features: ['text', 'label'],
    num_rows: 48
})

为了预热,我们将在这些真实标签上训练一个 SetFit 模型

model = SetFitModel.from_pretrained(model_id)

trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=reference_dataset["test"]
)
trainer.train()
fewshot_metrics = trainer.evaluate()
fewshot_metrics
{'accuracy': 0.4705}

请注意,对于此特定数据集,使用真实标签的性能在合成示例上训练更差! 在我们的实验中,我们发现差异很大程度上取决于所讨论的数据集。 由于 SetFit 模型训练速度很快,您可以始终尝试这两种方法并选择最佳方法。

在任何情况下,现在让我们向我们的训练集中添加一些合成示例

augmented_dataset = get_templated_dataset(train_dataset, candidate_labels=candidate_labels, sample_size=8)
augmented_dataset
Dataset({
    features: ['text', 'label'],
    num_rows: 96
})

和以前一样,我们可以使用增强的数据集来训练和评估 SetFit

model = SetFitModel.from_pretrained(model_id)

trainer = Trainer(
    model=model,
    train_dataset=augmented_dataset,
    eval_dataset=reference_dataset["test"]
)
trainer.train()
augmented_metrics = trainer.evaluate()
augmented_metrics
{'accuracy': 0.613}

太棒了,这给了我们性能的显着提升,并使我们比纯粹的合成示例高出几个百分点。

让我们绘制最终结果以进行比较

import pandas as pd

df = pd.DataFrame.from_dict({"Method":["Transformers (zero-shot)", "SetFit (zero-shot)", "SetFit (augmented)"], "Accuracy": [transformers_metrics["accuracy"], zeroshot_metrics["accuracy"], augmented_metrics["accuracy"]]})
df.plot(kind="barh", x="Method");                                       

setfit_zero_shot_results

< > GitHub 上更新