使用 SetFit 在零样本文本分类中进行数据标注的建议
作者:David Berenstein 和 Sara Han Díaz
建议是使标注团队的工作更轻松、更快捷的绝佳方式。这些预选选项将使标注过程更高效,因为他们只需要更正建议即可。在本例中,我们将演示如何使用 SetFit 实现零样本方法,以获取 Argilla 中数据集的一些初始建议,该数据集结合了两个文本分类任务,包括一个LabelQuestion
和一个MultiLabelQuestion
。
Argilla 是一个协作工具,供需要为其项目构建高质量数据集的 AI 工程师和领域专家使用。使用 Argilla,每个人都可以通过使用人工和机器反馈进行更快的的数据整理来构建强大的语言模型。
反馈是数据整理过程中至关重要的部分,Argilla 还提供了一种管理和可视化反馈的方式,以便稍后可以使用整理后的数据来改进语言模型。在本教程中,我们将展示一个如何通过提供建议来简化标注人员工作的真实示例。为了实现这一点,您将学习如何使用 SetFit 训练零样本情感和主题分类器,然后使用它们来为数据集建议标签。
在本教程中,我们将遵循以下步骤
- 在 Argilla 中创建数据集。
- 使用 SetFit 训练零样本分类器。
- 使用训练好的分类器获取数据集的建议。
- 在 Argilla 中可视化建议。
让我们开始吧!
设置
在本教程中,您需要运行 Argilla 服务器。如果您已部署 Argilla,则可以跳过此步骤。否则,您可以按照本指南快速在 HF Spaces 或本地部署 Argilla。完成后,请完成以下步骤
- 使用
pip
安装 Argilla 客户端和所需的第三方库
!pip install argilla
!pip install setfit==1.0.3 transformers==4.40.2 huggingface_hub==0.23.5
- 进行必要的导入
import argilla as rg
from datasets import load_dataset
from setfit import SetFitModel, Trainer, get_templated_dataset
- 如果您使用 Docker 快速启动镜像或 Hugging Face Spaces 运行 Argilla,则需要使用
API_URL
和API_KEY
初始化 Argilla 客户端。
# Replace api_url with your url if using Docker
# Replace api_key if you configured a custom API key
# Uncomment the last line and set your HF_TOKEN if your space is private
client = rg.Argilla(
api_url="https://[your-owner-name]-[your_space_name].hf.space",
api_key="[your-api-key]",
# headers={"Authorization": f"Bearer {HF_TOKEN}"}
)
配置数据集
在本示例中,我们将加载banking77数据集,这是一个流行的开源数据集,包含银行领域的客户请求。
data = load_dataset("PolyAI/banking77", split="test")
Argilla 使用Dataset
类,它可以轻松创建数据集并管理数据和反馈。首先需要配置Dataset
。在“设置”中,我们可以指定指南、将要添加的注释数据的字段以及标注者的问题。但是,还可以添加更多功能。有关更多信息,请查看Argilla 使用指南。
对于我们的用例,我们需要一个文本字段和两个不同的问题。我们将使用此数据集的原始标签对请求中提到的主题进行多标签分类,并将设置一个标签问题,将请求的情感分类为“正面”、“中性”或“负面”。
settings = rg.Settings(
fields=[rg.TextField(name="text")],
questions=[
rg.MultiLabelQuestion(
name="topics",
title="Select the topic(s) of the request",
labels=data.info.features["label"].names,
visible_labels=10,
),
rg.LabelQuestion(
name="sentiment",
title="What is the sentiment of the message?",
labels=["positive", "neutral", "negative"],
),
],
)
dataset = rg.Dataset(
name="setfit_tutorial_dataset",
settings=settings,
)
dataset.create()
训练模型
现在,我们将使用从 HF 加载的数据以及为数据集配置的标签和问题,为数据集中的每个问题训练一个零样本文本分类模型。如前几节所述,我们将使用SetFit框架对两个分类器中的 Sentence Transformers 进行少样本微调。此外,我们将使用的模型是all-MiniLM-L6-v2,这是一个在 10 亿个句子对数据集上使用对比目标进行微调的句子嵌入模型。
def train_model(question_name, template, multi_label=False):
train_dataset = get_templated_dataset(
candidate_labels=dataset.questions[question_name].labels,
sample_size=8,
template=template,
multi_label=multi_label,
)
# Train a model using the training dataset we just built
if multi_label:
model = SetFitModel.from_pretrained(
"sentence-transformers/all-MiniLM-L6-v2",
multi_target_strategy="one-vs-rest",
)
else:
model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
trainer = Trainer(model=model, train_dataset=train_dataset)
trainer.train()
return model
topic_model = train_model(
question_name="topics",
template="The customer request is about {}",
multi_label=True,
)
# topic_model.save_pretrained(
# "/path-to-your-models-folder/topic_model"
# )
sentiment_model = train_model(question_name="sentiment", template="This message is {}", multi_label=False)
# topic_model.save_pretrained(
# "/path-to-your-models-folder/sentiment_model"
# )
进行预测
训练步骤完成后,我们可以对数据进行预测。
def get_predictions(texts, model, question_name):
probas = model.predict_proba(texts, as_numpy=True)
labels = dataset.questions[question_name].labels
for pred in probas:
yield [{"label": label, "score": score} for label, score in zip(labels, pred)]
data = data.map(
lambda batch: {
"topics": list(get_predictions(batch["text"], topic_model, "topics")),
"sentiment": list(get_predictions(batch["text"], sentiment_model, "sentiment")),
},
batched=True,
)
data.to_pandas().head()
将记录日志到 Argilla
有了我们生成的数据和预测,我们现在可以构建记录(标注团队将标注的每个数据项),其中包含我们模型的建议。对于LabelQuestion
,我们将使用获得最高概率分数的标签;对于MultiLabelQuestion
,我们将包含所有得分高于某个阈值的标签。在本例中,我们决定使用2/len(labels)
,但您可以根据自己的数据进行实验,并决定使用更严格或更宽松的阈值。
请注意,更宽松的阈值(接近或等于
1/len(labels)
)将建议更多标签,而更严格的阈值(2 到 3 之间)将选择更少(或没有)标签。
def add_suggestions(record):
suggestions = []
# Get label with max score for sentiment question
sentiment = max(record["sentiment"], key=lambda x: x["score"])["label"]
suggestions.append(rg.Suggestion(question_name="sentiment", value=sentiment))
# Get all labels above a threshold for topics questions
threshold = 2 / len(dataset.questions["topics"].labels)
topics = [label["label"] for label in record["topics"] if label["score"] >= threshold]
if topics:
suggestions.append(rg.Suggestion(question_name="topics", value=topics))
return suggestions
records = [rg.Record(fields={"text": record["text"]}, suggestions=add_suggestions(record)) for record in data]
对结果满意后,我们可以将记录日志到上面配置的数据集中。您现在可以访问 Argilla 中的数据集并可视化建议。
dataset.records.log(records)
以下是带有模型建议的 UI 外观
此外,您还可以将 Argilla 数据集保存到 Hugging Face Hub 并从中加载。有关如何执行此操作的更多信息,请参阅Argilla 文档。
# Export to HuggingFace Hub
dataset.to_hub(repo_id="argilla/my_setfit_dataset")
# Import from HuggingFace Hub
dataset = rg.Dataset.from_hub(repo_id="argilla/my_setfit_dataset")
结论
在本教程中,我们介绍了如何使用 SetFit 库和零样本方法向 Argilla 数据集添加建议。这将有助于提高标注过程的效率,减少标注团队必须做出的决策和编辑次数。
查看以下链接以获取更多资源
< > 在 GitHub 上更新