SetFit 文档
零样本文本分类
并获得增强的文档体验
开始使用
零样本文本分类
尽管 SetFit 最初是为少样本学习设计的,但该方法也适用于没有标记数据的情况。 主要技巧是创建合成示例,使其类似于分类任务,然后在这些示例上训练 SetFit 模型。
值得注意的是,这种简单技术通常优于 🤗 Transformers 中的零样本流水线,并且可以快 5 倍(或更多)地生成预测!
在本教程中,我们将探讨如何
- 将 SetFit 应用于零样本分类
- 添加合成示例还可以提高少样本分类的性能。
设置
如果您在 Colab 或其他云平台上运行此 Notebook,则需要安装 setfit
库。 取消注释以下单元格并运行它
# %pip install setfit matplotlib
为了基准测试“零样本”方法的性能,我们将使用以下数据集和预训练模型
dataset_id = "emotion"
model_id = "sentence-transformers/paraphrase-mpnet-base-v2"
接下来,我们将从 Hugging Face Hub 下载参考数据集
from datasets import load_dataset
reference_dataset = load_dataset(dataset_id)
reference_dataset
DatasetDict({
train: Dataset({
features: ['text', 'label'],
num_rows: 16000
})
validation: Dataset({
features: ['text', 'label'],
num_rows: 2000
})
test: Dataset({
features: ['text', 'label'],
num_rows: 2000
})
})
现在我们已经设置好了,让我们创建一些合成数据来训练!
创建合成数据集
我们需要做的第一件事是创建一个合成示例数据集。 在 setfit
中,我们可以通过将 get_templated_dataset()
函数应用于虚拟数据集来完成此操作。 此函数需要几个主要内容
- 用于分类的候选标签列表。 我们将在此处使用来自参考数据集的标签,但这可以是与手头的任务和数据集相关的任何内容。
- 用于生成示例的模板。 默认情况下,它是
"This sentence is {}"
,其中{}
将由候选标签之一填充 - 样本大小 $N$,它将为每个类创建 $N$ 个合成示例。 我们发现 $N=8$ 通常效果最佳。
掌握了这些信息,让我们首先从数据集中提取一些候选标签
# Extract ClassLabel feature from "label" column
label_features = reference_dataset["train"].features["label"]
# Label names to classify with
candidate_labels = label_features.names
candidate_labels
['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
Hugging Face Hub 上的一些数据集在标签列中没有 ClassLabel
特征。 在这些情况下,您应该通过首先计算 id2label 映射来手动计算候选标签,如下所示
def get_id2label(dataset):
# The column with the label names
label_names = dataset.unique("label_text")
# The column with the label IDs
label_ids = dataset.unique("label")
id2label = dict(zip(label_ids, label_names))
# Sort by label ID
return {key: val for key, val in sorted(id2label.items(), key = lambda x: x[0])}
id2label = get_id2label(reference_dataset["train"])
candidate_labels = list(id2label.values())
现在我们有了标签,创建合成示例就变得很简单了
from datasets import Dataset
from setfit import get_templated_dataset
# A dummy dataset to fill with synthetic examples
dummy_dataset = Dataset.from_dict({})
train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8)
train_dataset
Dataset({
features: ['text', 'label'],
num_rows: 48
})
您可能会发现,通过将 template
参数从默认的 "The sentence is {}"
调整为诸如 "This sentence is {}"
或 "This example is {}"
之类的变体,可以获得更好的性能。
由于我们的数据集有 6 个类,并且我们选择了 8 的样本大小,因此我们的合成数据集包含 $6\times 8=48$ 个示例。 如果我们看一下其中的一些示例
train_dataset.shuffle()[:3]
{'text': ['This sentence is love',
'This sentence is fear',
'This sentence is joy'],
'label': [2, 4, 1]}
我们可以看到,每个输入都采用模板的形式,并具有与之关联的相应标签。
让我们在这些示例上训练一个 SetFit 模型!
微调模型
要训练 SetFit 模型,首先要做的是从 Hub 下载预训练的检查点。 我们可以通过使用 SetFitModel.from_pretrained()
方法来做到这一点
from setfit import SetFitModel
model = SetFitModel.from_pretrained(model_id)
在这里,我们从 Hub 下载了一个预训练的 Sentence Transformer,并添加了一个逻辑分类头来创建 SetFit 模型。 正如消息中所示,我们需要在一些标记示例上训练此模型。 我们可以通过使用 Trainer 类来做到这一点,如下所示
from setfit import Trainer
trainer = Trainer(
model=model,
train_dataset=train_dataset,
eval_dataset=reference_dataset["test"]
)
现在我们已经创建了一个训练器,我们可以训练它了! 在我们进行训练的同时,让我们计时训练和评估模型需要多长时间
%%time trainer.train() zeroshot_metrics = trainer.evaluate() zeroshot_metrics
***** Running training *****
Num examples = 1920
Num epochs = 1
Total optimization steps = 120
Total train batch size = 16
***** Running evaluation *****
{'accuracy': 0.5345}
CPU times: user 12.9 s, sys: 2.37 s, total: 15.2 s
Wall time: 11 s
太棒了,现在我们有了参考分数,让我们与 🤗 Transformers 的零样本流水线进行比较。
与 🤗 Transformers 的零样本流水线进行比较
🤗 Transformers 提供了一个零样本流水线,该流水线将文本分类构建为自然语言推理任务。 让我们加载流水线并将其放置在 GPU 上以进行快速推理
from transformers import pipeline
pipe = pipeline("zero-shot-classification", device=0)
现在我们有了模型,让我们生成一些预测。 我们将使用与 SetFit 相同的候选标签,并增加批大小以加快速度
%%time
zeroshot_preds = pipe(reference_dataset["test"]["text"], batch_size=16, candidate_labels=candidate_labels)
CPU times: user 1min 10s, sys: 166 ms, total: 1min 11s
Wall time: 53.1 s
请注意,这比 SetFit 生成预测花费的时间几乎长 5 倍! 好的,那么它的表现如何呢? 由于每个预测都是按分数排序的标签名称字典
zeroshot_preds[0]
{'sequence': 'im feeling rather rotten so im not very ambitious right now',
'labels': ['sadness', 'anger', 'surprise', 'fear', 'joy', 'love'],
'scores': [0.7367985844612122,
0.10041674226522446,
0.09770156443119049,
0.05880110710859299,
0.004266355652362108,
0.0020156768150627613]}
我们可以使用 label
列中的 str2int()
函数将它们转换为整数。
preds = [label_features.str2int(pred["labels"][0]) for pred in zeroshot_preds]
注意: 如前所述,如果您使用的数据集在标签列中没有 ClassLabel
特征,则需要使用类似以下内容手动计算标签映射
id2label = get_id2label(reference_dataset["train"])
label2id = {v:k for k,v in id2label.items()}
preds = [label2id[pred["labels"][0]] for pred in zeroshot_preds]
最后一步是使用 🤗 Evaluate 计算准确率
import evaluate
metric = evaluate.load("accuracy")
transformers_metrics = metric.compute(predictions=preds, references=reference_dataset["test"]["label"])
transformers_metrics
{'accuracy': 0.3765}
与 SetFit 相比,这种方法的性能明显更差。 让我们通过将合成示例与一些标记示例相结合来结束我们的分析。
使用合成示例增强标记数据
如果您有一些标记示例,则添加合成数据通常可以提高性能。 为了模拟这一点,让我们首先从我们的参考数据集中采样 8 个标记示例
from setfit import sample_dataset
train_dataset = sample_dataset(reference_dataset["train"])
train_dataset
Dataset({
features: ['text', 'label'],
num_rows: 48
})
为了预热,我们将在这些真实标签上训练一个 SetFit 模型
model = SetFitModel.from_pretrained(model_id)
trainer = Trainer(
model=model,
train_dataset=train_dataset,
eval_dataset=reference_dataset["test"]
)
trainer.train()
fewshot_metrics = trainer.evaluate()
fewshot_metrics
{'accuracy': 0.4705}
请注意,对于此特定数据集,使用真实标签的性能比在合成示例上训练更差! 在我们的实验中,我们发现差异很大程度上取决于所讨论的数据集。 由于 SetFit 模型训练速度很快,您可以始终尝试这两种方法并选择最佳方法。
在任何情况下,现在让我们向我们的训练集中添加一些合成示例
augmented_dataset = get_templated_dataset(train_dataset, candidate_labels=candidate_labels, sample_size=8)
augmented_dataset
Dataset({
features: ['text', 'label'],
num_rows: 96
})
和以前一样,我们可以使用增强的数据集来训练和评估 SetFit
model = SetFitModel.from_pretrained(model_id)
trainer = Trainer(
model=model,
train_dataset=augmented_dataset,
eval_dataset=reference_dataset["test"]
)
trainer.train()
augmented_metrics = trainer.evaluate()
augmented_metrics
{'accuracy': 0.613}
太棒了,这给了我们性能的显着提升,并使我们比纯粹的合成示例高出几个百分点。
让我们绘制最终结果以进行比较
import pandas as pd
df = pd.DataFrame.from_dict({"Method":["Transformers (zero-shot)", "SetFit (zero-shot)", "SetFit (augmented)"], "Accuracy": [transformers_metrics["accuracy"], zeroshot_metrics["accuracy"], augmented_metrics["accuracy"]]})
df.plot(kind="barh", x="Method");