SetFit 文档
零样本文本分类
并获得增强的文档体验
开始使用
零样本文本分类
虽然 SetFit 是为少样本学习而设计的,但该方法也可以应用于没有标记数据的场景。主要的技巧是创建类似于分类任务的*合成示例*,然后使用它们训练 SetFit 模型。
值得注意的是,这种简单的技术通常优于 🤗 Transformers 中的零样本管道,并且预测速度可以快 5 倍(或更多)!
在本教程中,我们将探讨如何
- SetFit 可以应用于零样本分类
- 添加合成示例还可以为少样本分类提供性能提升。
设置
如果您在 Colab 或其他云平台上运行此 Notebook,则需要安装 `setfit` 库。取消注释以下单元格并运行它
# %pip install setfit matplotlib
为了基准测试“零样本”方法的性能,我们将使用以下数据集和预训练模型
dataset_id = "emotion"
model_id = "sentence-transformers/paraphrase-mpnet-base-v2"
接下来,我们将从 Hugging Face Hub 下载参考数据集
from datasets import load_dataset
reference_dataset = load_dataset(dataset_id)
reference_dataset
DatasetDict({
train: Dataset({
features: ['text', 'label'],
num_rows: 16000
})
validation: Dataset({
features: ['text', 'label'],
num_rows: 2000
})
test: Dataset({
features: ['text', 'label'],
num_rows: 2000
})
})
现在我们已经设置好,让我们创建一些合成数据进行训练!
创建合成数据集
我们需要做的第一件事是创建一个合成示例数据集。在 `setfit` 中,我们可以通过将 `get_templated_dataset()` 函数应用于虚拟数据集来做到这一点。此函数需要几个主要内容
- 一个用于分类的候选标签列表。我们将在此处使用参考数据集中的标签,但这可以是与任务和当前数据集相关的任何内容。
- 一个用于生成示例的模板。默认情况下,它是 `"This sentence is {}"`,其中 `{}` 将由一个候选标签填充
- 一个样本大小 $N$,它将为每个类创建 $N$ 个合成示例。我们发现 $N=8$ 通常效果最好。
有了这些信息,我们首先从数据集中提取一些候选标签
# Extract ClassLabel feature from "label" column
label_features = reference_dataset["train"].features["label"]
# Label names to classify with
candidate_labels = label_features.names
candidate_labels
['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
Hugging Face Hub 上的一些数据集在标签列中没有 `ClassLabel` 特性。在这种情况下,您应该首先计算 id2label 映射,然后手动计算候选标签,如下所示
def get_id2label(dataset):
# The column with the label names
label_names = dataset.unique("label_text")
# The column with the label IDs
label_ids = dataset.unique("label")
id2label = dict(zip(label_ids, label_names))
# Sort by label ID
return {key: val for key, val in sorted(id2label.items(), key = lambda x: x[0])}
id2label = get_id2label(reference_dataset["train"])
candidate_labels = list(id2label.values())
现在我们有了标签,创建合成示例就变得很简单了
from datasets import Dataset
from setfit import get_templated_dataset
# A dummy dataset to fill with synthetic examples
dummy_dataset = Dataset.from_dict({})
train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8)
train_dataset
Dataset({
features: ['text', 'label'],
num_rows: 48
})
您可能会发现,通过将 `template` 参数从默认的 `"The sentence is {}"` 调整为 `"This sentence is {}"` 或 `"This example is {}"` 等变体,可以获得更好的性能。
由于我们的数据集有 6 个类,我们选择了 8 的样本大小,因此我们的合成数据集包含 $6\times 8=48$ 个示例。如果我们查看一些示例
train_dataset.shuffle()[:3]
{'text': ['This sentence is love',
'This sentence is fear',
'This sentence is joy'],
'label': [2, 4, 1]}
我们可以看到每个输入都采用模板的形式,并具有与之关联的相应标签。
我们不要在这些示例上训练 SetFit 模型!
微调模型
要训练 SetFit 模型,首先要从 Hub 下载预训练检查点。我们可以通过使用 `SetFitModel.from_pretrained()` 方法来做到这一点
from setfit import SetFitModel
model = SetFitModel.from_pretrained(model_id)
在这里,我们从 Hub 下载了一个预训练的 Sentence Transformer,并添加了一个逻辑分类头来创建 SetFit 模型。如消息所示,我们需要在一些标记示例上训练这个模型。我们可以通过使用 Trainer 类来做到这一点
from setfit import Trainer
trainer = Trainer(
model=model,
train_dataset=train_dataset,
eval_dataset=reference_dataset["test"]
)
现在我们已经创建了一个训练器,我们可以训练它了!同时,让我们记录训练和评估模型所需的时间
%%time trainer.train() zeroshot_metrics = trainer.evaluate() zeroshot_metrics
***** Running training *****
Num examples = 1920
Num epochs = 1
Total optimization steps = 120
Total train batch size = 16
***** Running evaluation *****
{'accuracy': 0.5345}
CPU times: user 12.9 s, sys: 2.37 s, total: 15.2 s
Wall time: 11 s
太好了,现在我们有了一个参考分数,让我们与 🤗 Transformers 中的零样本管道进行比较。
与 🤗 Transformers 中的零样本管道进行比较
🤗 Transformers 提供了一个零样本管道,将文本分类构建为自然语言推理任务。让我们加载管道并将其放在 GPU 上以实现快速推理
from transformers import pipeline
pipe = pipeline("zero-shot-classification", device=0)
现在我们有了模型,让我们生成一些预测。我们将使用与 SetFit 相同的候选标签,并增加批处理大小以加快速度
%%time
zeroshot_preds = pipe(reference_dataset["test"]["text"], batch_size=16, candidate_labels=candidate_labels)
CPU times: user 1min 10s, sys: 166 ms, total: 1min 11s
Wall time: 53.1 s
请注意,这比 SetFit 生成预测的时间长了近 5 倍!好的,那么它的表现如何呢?由于每个预测都是按分数排名的标签名称字典
zeroshot_preds[0]
{'sequence': 'im feeling rather rotten so im not very ambitious right now',
'labels': ['sadness', 'anger', 'surprise', 'fear', 'joy', 'love'],
'scores': [0.7367985844612122,
0.10041674226522446,
0.09770156443119049,
0.05880110710859299,
0.004266355652362108,
0.0020156768150627613]}
我们可以使用 `label` 列中的 `str2int()` 函数将它们转换为整数。
preds = [label_features.str2int(pred["labels"][0]) for pred in zeroshot_preds]
**注意:** 如前所述,如果您使用的数据集的标签列没有 `ClassLabel` 特性,则需要手动计算标签映射,例如
id2label = get_id2label(reference_dataset["train"])
label2id = {v:k for k,v in id2label.items()}
preds = [label2id[pred["labels"][0]] for pred in zeroshot_preds]
最后一步是使用 🤗 Evaluate 计算准确率
import evaluate
metric = evaluate.load("accuracy")
transformers_metrics = metric.compute(predictions=preds, references=reference_dataset["test"]["label"])
transformers_metrics
{'accuracy': 0.3765}
与 SetFit 相比,这种方法的性能明显更差。让我们通过将合成示例与一些标记示例相结合来结束我们的分析。
用合成示例增强标记数据
如果您有一些标记示例,添加合成数据通常可以提高性能。为了模拟这一点,我们首先从参考数据集中采样 8 个标记示例
from setfit import sample_dataset
train_dataset = sample_dataset(reference_dataset["train"])
train_dataset
Dataset({
features: ['text', 'label'],
num_rows: 48
})
预热一下,我们将用这些真实标签训练一个 SetFit 模型
model = SetFitModel.from_pretrained(model_id)
trainer = Trainer(
model=model,
train_dataset=train_dataset,
eval_dataset=reference_dataset["test"]
)
trainer.train()
fewshot_metrics = trainer.evaluate()
fewshot_metrics
{'accuracy': 0.4705}
请注意,对于这个特定的数据集,使用真实标签的性能*比*使用合成示例训练的性能*更差*!在我们的实验中,我们发现差异很大程度上取决于具体的数据集。由于 SetFit 模型训练速度快,您总是可以尝试两种方法并选择最佳的一种。
无论如何,现在让我们向训练集中添加一些合成示例
augmented_dataset = get_templated_dataset(train_dataset, candidate_labels=candidate_labels, sample_size=8)
augmented_dataset
Dataset({
features: ['text', 'label'],
num_rows: 96
})
和以前一样,我们可以用增强数据集训练和评估 SetFit
model = SetFitModel.from_pretrained(model_id)
trainer = Trainer(
model=model,
train_dataset=augmented_dataset,
eval_dataset=reference_dataset["test"]
)
trainer.train()
augmented_metrics = trainer.evaluate()
augmented_metrics
{'accuracy': 0.613}
太好了,这大大提升了我们的性能,比纯粹的合成示例提高了几个百分点。
让我们绘制最终结果进行比较
import pandas as pd
df = pd.DataFrame.from_dict({"Method":["Transformers (zero-shot)", "SetFit (zero-shot)", "SetFit (augmented)"], "Accuracy": [transformers_metrics["accuracy"], zeroshot_metrics["accuracy"], augmented_metrics["accuracy"]]})
df.plot(kind="barh", x="Method");