开源 AI 食谱文档
在零样本文字分类中使用 SetFit 进行数据标注的建议
并获得增强的文档体验
开始使用
在零样本文字分类中使用 SetFit 进行数据标注的建议
作者:David Berenstein 和 Sara Han Díaz
建议是让标注团队工作更轻松、更快捷的绝佳方式。这些预选选项将使标注过程更高效,因为他们只需要纠正建议即可。在本例中,我们将演示如何使用 SetFit 实现零样本方法,为 Argilla 中的一个数据集获取一些初步建议,该数据集结合了两种文本分类任务,包括 LabelQuestion
和 MultiLabelQuestion
。
Argilla 是一个面向 AI 工程师和领域专家的协作工具,他们需要为自己的项目构建高质量的数据集。通过使用 Argilla,每个人都可以通过结合人工和机器反馈,加快数据整理过程,从而构建出稳健的语言模型。
反馈是数据整理过程中的关键部分,Argilla 也提供了一种管理和可视化反馈的方法,以便后续使用整理好的数据来改进语言模型。在本教程中,我们将展示一个真实案例,说明如何通过提供建议来简化标注者的工作。为实现这一目标,您将学习如何使用 SetFit 训练零样本情感和主题分类器,然后用它们来为数据集建议标签。
在本教程中,我们将遵循以下步骤:
- 在 Argilla 中创建一个数据集。
- 使用 SetFit 训练零样本分类器。
- 使用训练好的分类器为数据集获取建议。
- 在 Argilla 中可视化建议。
让我们开始吧!
环境设置
在本教程中,您需要运行一个 Argilla 服务器。如果您已经部署了 Argilla,可以跳过此步骤。否则,您可以按照本指南在 HF Spaces 或本地快速部署 Argilla。完成后,请执行以下步骤:
- 使用
pip
安装 Argilla 客户端和所需的第三方库
!pip install argilla
!pip install setfit==1.0.3 transformers==4.40.2 huggingface_hub==0.23.5
- 导入必要的库
import argilla as rg
from datasets import load_dataset
from setfit import SetFitModel, Trainer, get_templated_dataset
- 如果您使用 Docker 快速启动镜像或 Hugging Face Spaces 运行 Argilla,您需要使用
API_URL
和API_KEY
初始化 Argilla 客户端
# Replace api_url with your url if using Docker
# Replace api_key if you configured a custom API key
# Uncomment the last line and set your HF_TOKEN if your space is private
client = rg.Argilla(
api_url="https://[your-owner-name]-[your_space_name].hf.space",
api_key="[your-api-key]",
# headers={"Authorization": f"Bearer {HF_TOKEN}"}
)
配置数据集
在本示例中,我们将加载 banking77 数据集,这是一个流行的开源数据集,包含银行领域的客户请求。
data = load_dataset("PolyAI/banking77", split="test")
Argilla 使用 Dataset
类,这使您可以轻松创建数据集并管理数据和反馈。首先需要对 Dataset
进行配置。在 Settings
中,我们可以指定标注*指南*、待标注数据将添加到的*字段*以及给标注者的*问题*。此外,还可以添加更多功能。更多信息,请查看 Argilla 操作指南。
对于我们的用例,我们需要一个文本字段和两个不同的问题。我们将使用该数据集的原始标签对请求中提到的主题进行多标签分类,并且我们还将设置一个标签问题,将请求的情感分类为“积极”、“中性”或“消极”。
settings = rg.Settings(
fields=[rg.TextField(name="text")],
questions=[
rg.MultiLabelQuestion(
name="topics",
title="Select the topic(s) of the request",
labels=data.info.features["label"].names,
visible_labels=10,
),
rg.LabelQuestion(
name="sentiment",
title="What is the sentiment of the message?",
labels=["positive", "neutral", "negative"],
),
],
)
dataset = rg.Dataset(
name="setfit_tutorial_dataset",
settings=settings,
)
dataset.create()
训练模型
现在,我们将使用从 HF 加载的数据以及为数据集配置的标签和问题,为数据集中的每个问题训练一个零样本文本分类模型。如前几节所述,我们将在两个分类器中都使用 SetFit 框架对 Sentence Transformers 进行少样本微调。此外,我们将使用的模型是 all-MiniLM-L6-v2,这是一个句子嵌入模型,它在一个包含 10 亿个句子对的数据集上使用对比目标进行了微调。
def train_model(question_name, template, multi_label=False):
train_dataset = get_templated_dataset(
candidate_labels=dataset.questions[question_name].labels,
sample_size=8,
template=template,
multi_label=multi_label,
)
# Train a model using the training dataset we just built
if multi_label:
model = SetFitModel.from_pretrained(
"sentence-transformers/all-MiniLM-L6-v2",
multi_target_strategy="one-vs-rest",
)
else:
model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
trainer = Trainer(model=model, train_dataset=train_dataset)
trainer.train()
return model
topic_model = train_model(
question_name="topics",
template="The customer request is about {}",
multi_label=True,
)
# topic_model.save_pretrained(
# "/path-to-your-models-folder/topic_model"
# )
sentiment_model = train_model(question_name="sentiment", template="This message is {}", multi_label=False)
# topic_model.save_pretrained(
# "/path-to-your-models-folder/sentiment_model"
# )
进行预测
训练步骤结束后,我们可以对数据进行预测。
def get_predictions(texts, model, question_name):
probas = model.predict_proba(texts, as_numpy=True)
labels = dataset.questions[question_name].labels
for pred in probas:
yield [{"label": label, "score": score} for label, score in zip(labels, pred)]
data = data.map(
lambda batch: {
"topics": list(get_predictions(batch["text"], topic_model, "topics")),
"sentiment": list(get_predictions(batch["text"], sentiment_model, "sentiment")),
},
batched=True,
)
data.to_pandas().head()
将记录写入 Argilla
利用我们生成的数据和预测结果,我们现在可以构建包含模型建议的记录 (标注团队将要标注的每个数据项)。对于 LabelQuestion
,我们将使用获得最高概率分数的标签;对于 MultiLabelQuestion
,我们将包含所有分数高于某个阈值的标签。在本例中,我们决定使用 2/len(labels)
,但您可以根据自己的数据进行实验,选择一个更严格或更宽松的阈值。
请注意,更宽松的阈值 (接近或等于
1/len(labels)
) 会建议更多标签,而更严格的阈值 (介于 2 和 3 之间) 则会选择更少 (或没有) 标签。
def add_suggestions(record):
suggestions = []
# Get label with max score for sentiment question
sentiment = max(record["sentiment"], key=lambda x: x["score"])["label"]
suggestions.append(rg.Suggestion(question_name="sentiment", value=sentiment))
# Get all labels above a threshold for topics questions
threshold = 2 / len(dataset.questions["topics"].labels)
topics = [label["label"] for label in record["topics"] if label["score"] >= threshold]
if topics:
suggestions.append(rg.Suggestion(question_name="topics", value=topics))
return suggestions
records = [rg.Record(fields={"text": record["text"]}, suggestions=add_suggestions(record)) for record in data]
一旦我们对结果满意,就可以将记录写入上面配置的数据集。现在您可以在 Argilla 中访问该数据集并查看建议。
dataset.records.log(records)
以下是 UI 显示我们模型建议的样子:
或者,您也可以将 Argilla 数据集保存并加载到 Hugging Face Hub 中。有关如何操作的更多信息,请参阅 Argilla 文档。
# Export to HuggingFace Hub
dataset.to_hub(repo_id="argilla/my_setfit_dataset")
# Import from HuggingFace Hub
dataset = rg.Dataset.from_hub(repo_id="argilla/my_setfit_dataset")
结论
在本教程中,我们介绍了如何使用 SetFit 库通过零样本方法向 Argilla 数据集添加建议。这将通过减少标注团队需要做出的决策和编辑次数来提高标注过程的效率。
查看以下链接获取更多资源:
< > 在 GitHub 上更新