使用 Argilla 和 AutoTrain 对法律数据的 token 分类模型进行微调
我们都希望尝试使用现有工具/技术解决一些用例。在本教程中,我想回顾我在美国专利文本上微调模型的学习历程。
1. 简介
1.1 命名实体识别 (NER) 背景
命名实体识别 (NER) 是自然语言处理 (NLP) 中的一项基本任务,它涉及将文本中的命名实体识别并分类为预定义的类别,例如人名、组织、地点、医疗代码、时间表达式、数量、货币值等。
1.2 NER 在自然语言处理中的重要性
NER 在各种 NLP 应用中扮演着至关重要的角色,包括:
- 信息检索
- 问答系统
- 机器翻译
- 文本摘要
- 情感分析
1.3 特定领域或语言 NER 面临的挑战
虽然通用 NER 模型已经存在,但由于以下原因,它们在应用于专业领域或不常用语言时往往表现不佳:
- 领域特定术语
- 独特的实体类型
- 语言特定细微差别
1.4 对自定义微调模型的需求
为了解决这些挑战,微调自定义 NER 模型变得至关重要。这种方法允许:
- 适应特定领域:微调模型在特定任务或领域上的表现优于通用模型。
- 效率:微调模型通常需要更少的数据和计算资源即可在特定任务上获得良好性能。
- 更快的推理:较小的、特定任务的模型比大型通用模型运行得更快。
1.5 项目目标和概述
本项目旨在为 USPTO 专利微调自定义 NER 模型。我们的目标包括:
- 使用 Hugging Face Spaces 设置 Argilla 实例。
- 使用 Argilla UI 使用自定义标签注释我们的数据集。
- 使用 Hugging Face AutoTrain 创建一个在大小和推理速度方面更高效的模型。
- 展示迁移学习在 NER 任务中的有效性。
2. 数据背景
美国专利文本通常是关于发明的冗长描述性文档。本教程中使用的数据可通过 Kaggle USPTO 竞赛 获取。每项专利都包含以下字段:
- 标题
- 摘要
- 声明
- 描述
在本教程中,我们将重点关注 claims
字段。
2.1 问题陈述
我们的目标是微调一个模型,以对给定专利的 claims
字段中的 token 进行分类。
2.2 拆解问题
为实现这一目标,我们需要:
- 高质量数据以微调预训练的 token 分类模型
- 执行训练的基础设施
3. 使用 Argilla 创建高质量数据
Argilla 是一个出色的工具,可通过用户友好的界面进行标注,从而创建高质量数据集。
3.1 在 Hugging Face Spaces 上设置 Argilla
1. 访问 Hugging Face Spaces 部署页面
2. 创建一个新空间:
- 提供一个名称
- 选择
Docker
作为 Space SDK - 选择
Argilla
作为 Docker 模板 - 为简化起见,其他字段留空
- 点击
创建空间
3. 重启空间
现在您已在 Hugging Face Spaces 上运行 Argilla 实例。点击您创建的空间以进入 Argilla UI 的登录界面。使用以下凭据访问 UI:
- 用户名:
admin
- 密码:
12345678
[默认密码]
有关更多选项和为生产用例设置 Argilla 实例的信息,请参阅 在 Huggingface 上配置 Argilla
3.2 使用 Argilla Python SDK 创建数据集
步骤 1:安装和导入包
!pip install -U datasets argilla autotrain-advanced==0.8.8 > install_logs.txt 2>&1
import argilla as rg
import pandas as pd
import re
import os
import random
import torch
from IPython.display import Image, display,HTML
from datasets import load_dataset, Dataset, DatasetDict,ClassLabel,Sequence,Value,Features
from transformers import pipeline,TokenClassificationPipeline
from typing import List, Dict, Union,Tuple
from google.colab import userdata
步骤 2:初始化 Argilla 客户端
api_url:我们可以通过使用 https://huggingface.co/spaces/<hf_username>/<hf_space_name>?embed=true
获取此 URL
client = rg.Argilla(
api_url="https://bikashpatra-argilla-uspto-labelling.hf.space",
#api_url="https://<hf_username>-<hf_space_name>.hf.space # This is url to my public space.
api_key="admin.apikey", # default value. Shouldn't be used for production.
# headers={"Authorization": f"Bearer {HF_TOKEN}"}
)
#Replace `<hf-username>` and `<space-name>` with your actual Hugging Face username and space name.
步骤 3:配置数据集
为了配置用于 token 分类任务的 Argilla 数据集,我们需要:
为我们的问题领域创建特定标签:我通过使用以下提示创建了一些标签:
建议一些可以用来注释美国专利权利要求或描述部分中 token 的标签,例如“过程”、“产品”、“物质组成”
我们需要配置数据集的字段/列和
questions
。questions
参数允许您指导/引导注释者完成任务。在我们的用例中,我们将使用我们为注释者创建的labels
来选择在注释文本片段(token)时使用。
# Labels for token classification
labels = [
"Process", "Product", "Composition of Matter", "Method of Use",
"Software", "Hardware", "Algorithm", "System", "Device",
"Apparatus", "Method", "Machine", "Manufacture", "Design",
"Pharmaceutical Formulation", "Biotechnology", "Chemical Compound",
"Electrical Circuit"
]
# Dataset settings
settings = rg.Settings(
guidelines="Classify individual tokens according to the specified categories, ensuring that any overlapping or nested entities are accurately captured.",
fields=[
rg.TextField(name="tokens", title="Text", use_markdown=True),
rg.TextField(name="document_id", title="publication_number", use_markdown=True),
rg.TextField(name="sentence_id", title="sentence_id", use_markdown=False)
],
questions=[
rg.SpanQuestion(
name="span_label",
field="tokens",
labels=labels,
title="Classify the tokens according to the specified categories.",
allow_overlapping=True
)
]
)
步骤 4:在 Argilla 实例上创建数据集
完成设置后,我们就可以使用 rg.Dataset
API 创建数据集了。
# We name the dataset as claim_tokens
rg_dataset = rg.Dataset(
name="claim_tokens",
settings=settings,
)
rg_dataset.create()
/usr/local/lib/python3.10/dist-packages/argilla/datasets/_resource.py:202: UserWarning: Workspace not provided. Using default workspace: admin id: fd4fc24c-fc1f-4ffe-af41-d569432d6b50
warnings.warn(f"Workspace not provided. Using default workspace: {workspace.name} id: {workspace.id}")
Dataset(id=UUID('a187cdad-175e-4d87-989f-a529b9999bde') inserted_at=datetime.datetime(2024, 7, 28, 7, 23, 59, 902685) updated_at=datetime.datetime(2024, 7, 28, 7, 24, 1, 901701) name='claim_tokens' status='ready' guidelines='Classify individual tokens according to the specified categories, ensuring that any overlapping or nested entities are accurately captured.' allow_extra_metadata=False workspace_id=UUID('fd4fc24c-fc1f-4ffe-af41-d569432d6b50') last_activity_at=datetime.datetime(2024, 7, 28, 7, 24, 1, 901701) url=None)
在步骤 4 之后,我们应该能在 Argilla UI 中看到创建的数据集。我们可以通过使用默认凭据登录 Argilla UI(URL 为 `https://huggingface.co/spaces/<hf-username>-<space-name>.hf.space)来验证。
我们可以通过点击数据集名称旁边的设置图标来查看数据集的设置。
def display_image(filename): display(Image(filename=filename))
display_image('/content/images/argilla_ds_list_settings.png')
设置屏幕的“字段”选项卡列出了我们在使用 Python SDK 创建数据集时配置的字段。
display_image('/content/images/argilla_ds_settings.png')
步骤 5:将记录插入 Argilla 数据集
数据准备笔记本可在此处找到:here
claims = pd.read_csv("/content/sample_publications.csv")
claims.head(2)
发布号 | 序列号 | 标记 | |
---|---|---|---|
0 | US-4444749-A | 0 | 一种洗发水,包含水溶液的... |
1 | US-4444749-A | 1 | 一种洗发水,包含水溶液的... |
这里我们正在读取 CSV 的行并将它们映射到我们在 Argilla 数据集配置步骤中创建的字段。
## We upload a csv with three columns : tokens, publication_number, sequence_id
publication_df = pd.read_csv("/content/sample_publications.csv")
## Convert dataframe rows to Argilla Records
records = [
rg.Record(
fields=
{"tokens": "".join(row["tokens"])
,'document_id':str(row['publication_number'])
,'sentence_id':str(row['sequence_id'])
})
for _,row in publication_df.iterrows()
]
## Store Argilla records to Argilla Dataset
rg_dataset.records.log(records)
DatasetRecords: The provided batch size 256 was normalized. Using value 149.
Sending records...: 100%|██████████| 1/1 [00:00<00:00, 1.71batch/s]
DatasetRecords(Dataset(id=UUID('a187cdad-175e-4d87-989f-a529b9999bde') inserted_at=datetime.datetime(2024, 7, 28, 7, 23, 59, 902685) updated_at=datetime.datetime(2024, 7, 28, 7, 24, 1, 901701) name='claim_tokens' status='ready' guidelines='Classify individual tokens according to the specified categories, ensuring that any overlapping or nested entities are accurately captured.' allow_extra_metadata=False workspace_id=UUID('fd4fc24c-fc1f-4ffe-af41-d569432d6b50') last_activity_at=datetime.datetime(2024, 7, 28, 7, 24, 1, 901701) url=None))
一旦我们将记录推送到 Argilla 数据集,UI 将渲染记录和标签,供注释者注释文本。
请查看下面的屏幕截图。
display_image("/content/images/annotation_screen.png")
步骤 6:使用适当的标签注释每条记录中的 token。
登录 Argilla UI 并开始注释。
Argilla UI:https://huggingface.co/spaces/<hf-username>/<hf-space-name>
用户名:admin
密码:12345678
注释数据后,我们需要将 Argilla 数据集转换为 HuggingFace 数据集,以便使用 HuggingFace AutoTrain 进行模型微调。HF AutoTrain 也允许在 CSV 数据上进行训练,可以通过 AutoTrain UI 上传。但对于本教程,我们将使用 Huggingface 数据集。
4. Argilla 数据集到 HuggingFace 数据集
步骤 1:加载我们已注释的数据集
rg_dataset = client.datasets("claim_tokens")
步骤 2:过滤已注释的行/记录。
为了快速迭代注释和训练,我们应该能够注释少量记录并训练我们的模型。我们可以通过使用 Argilla 数据集的查询/过滤功能来实现。
使用 rg.Query()
API,我们可以过滤已注释的记录,以准备我们的训练数据集。
status_filter = rg.Query(filter=rg.Filter(("response.status", "==", "submitted")))
submitted = rg_dataset.records(status_filter).to_list(flatten=True)
submitted[0]
{'id': '01e9b4bb-9c98-4cec-acea-dd686cddf5f0',
'status': 'pending',
'_server_id': '0b6f16f3-c3dc-4947-ac77-8b65002bf350',
'tokens': 'The FINFET of claim 11 , wherein the conformal gate dielectric comprises a high-κ gate dielectric selected from the group consisting of: hafnium oxide (HfO 2 ), lanthanum oxide (La 2 O 3 ), and combinations thereof.',
'document_id': 'US-11631617-B2',
'sentence_id': '14',
'span_label.responses': [[{'label': 'Electrical Circuit',
'start': 4,
'end': 10},
{'label': 'Chemical Compound', 'start': 138, 'end': 151},
{'label': 'Chemical Compound', 'start': 162, 'end': 177}]],
'span_label.responses.users': ['4e9588d6-e2d6-450d-82c6-b33324d94708'],
'span_label.responses.status': ['submitted']}
注释过的数据集不能直接用于模型微调。对于 token 分类任务,我们必须使数据符合以下结构。
- 数据集结构:数据集通常应包含两个主要列
tokens
:每个示例的单词/token 列表。ner_tags
:每个 token 对应的标签列表。标签必须遵循 IOB 标注方案。
- 标签编码:标签应该是整数,每个整数对应一个特定的命名实体标签。以下函数将使我们能够将 Argilla 数据集转换为所需的数据集结构。
def get_iob_tag_for_token(token_start:int, token_end:int, ner_spans:List[Dict[str, Union[int, str]]]) -> str:
"""
Determine the IOB tag for a given token based on its position within NER spans.
Args:
token_start (int): The start index of the token in the text.
token_end (int): The end index of the token in the text.
ner_spans (List[Dict[str, Union[int, str]]]): A list of dictionaries containing NER span information.
Each dictionary should have 'start', 'end', and 'label' keys.
Returns:
str: The IOB tag for the token. 'B-' prefix for the beginning of an entity,
'I-' for inside an entity, or 'O' for outside any entity.
"""
for span in ner_spans:
if token_start >= span["start"] and token_end <= span["end"]:
if token_start == span["start"]:
return f"B-{span['label']}"
else:
return f"I-{span['label']}"
return "O"
def extract_ner_tags(text:str, responses:List[Dict[str, Union[int, str]]]):
"""
Extract NER tags for tokens in the given text based on the provided NER responses.
Args:
text (str): The input text to be tokenized and tagged.
responses (List[Dict[str, Union[int, str]]]): A list of dictionaries containing NER span information.
Each dictionary should have 'start', 'end', and 'label' keys.
Returns:
List[str]: A list of IOB tags corresponding to each non-whitespace token in the text.
"""
tokens = re.split(r"(\s+)", text)
ner_tags = []
current_position = 0
for token in tokens:
if token.strip():
token_start = current_position
token_end = current_position + len(token)
tag = get_iob_tag_for_token(token_start, token_end, responses)
ner_tags.append(tag)
current_position += len(token)
return ner_tags
步骤 3:获取 token 及其各自的注释
def get_tokens_ner_tags(annotated_rows) -> Tuple[List[List[str]], List[List[str]]]:
"""
Extract tokens and their corresponding NER tags from annotated rows.
This function processes a list of annotated rows, where each row contains
tokens and span labels. It splits the tokens and extracts NER tags for each token.
Args:
annotated_rows (List[Dict[str, Union[str, List[Dict[str, Union[int, str]]]]]]):
A list of dictionaries, where each dictionary represents an annotated row.
Each row should have a 'tokens' key (str) and a 'span_label.responses' key
(List[Dict[str, Union[int, str]]]).
Returns:
Tuple[List[List[str]], List[List[str]]]: A tuple containing two elements:
1. A list of token lists, where each inner list represents tokens for a row.
2. A list of NER tag lists, where each inner list represents NER tags for a row.
"""
tokens = []
ner_tags = []
for idx,row in enumerate(annotated_rows):
tags = extract_ner_tags(row["tokens"], row["span_label.responses"][0])
tks = row["tokens"].split()
tokens.append(tks)
ner_tags.append(tags)
return tokens, ner_tags
train_tokens, train_ner_tags = get_tokens_ner_tags(submitted[:1])
validation_tokens, validation_ner_tags = get_tokens_ner_tags(submitted[1:2])
心情检查
在几次操作后检查我们的数据总是好的。这将帮助我们理解和调试,如果每一步的输出都产生了预期的输出。
display(HTML('''
<style>
pre {
white-space: pre-wrap;
word-wrap: break-word;
}
.colored-header {
color: blue; /* Change 'blue' to any color you prefer */
font-size: 16px;
margin-bottom: 8px;
}
</style>
'''))
display(HTML("<pre><span class='colored-header'>Sample Train Tokens:</span>" +
f"{train_tokens[0]}</pre><br>"))
display(HTML("<pre><span class='colored-header'>Sample Valid Tokens:</span>" +
f"{validation_tokens[0]}</pre><br>"))
display(HTML("<pre><span class='colored-header'>Sample Train tags:</span>" +
f"{train_ner_tags[0]}</pre><br>"))
display(HTML("<pre><span class='colored-header'>Sample Valid tags:</span>" +
f"{validation_ner_tags[0]}</pre>"))
Sample Train Tokens:['The', 'FINFET', 'of', 'claim', '11', ',', 'wherein', 'the', 'conformal', 'gate', 'dielectric', 'comprises', 'a', 'high-κ', 'gate', 'dielectric', 'selected', 'from', 'the', 'group', 'consisting', 'of:', 'hafnium', 'oxide', '(HfO', '2', '),', 'lanthanum', 'oxide', '(La', '2', 'O', '3', '),', 'and', 'combinations', 'thereof.']
Sample Valid Tokens:['The', 'method', 'of', 'claim', '2', ',', 'wherein', 'generating', 'the', 'one', 'or', 'more', 'possible', 'design', 'modification', 'solutions', 'based', 'at', 'least', 'in', 'part', 'on', 'the', 'set', 'of', 'attack', 'mitigation', 'rules', 'comprises', 'generating', 'the', 'one', 'or', 'more', 'possible', 'design', 'modification', 'solutions', 'by', 'inputting', 'the', 'set', 'of', 'attack', 'mitigation', 'rules', 'to', 'a', 'model', 'configured', 'to', 'perform', 'structural', 'and', 'functional', 'analysis', 'to', 'interpret', 'the', 'set', 'of', 'attack', 'mitigation', 'rules,', 'wherein', 'the', 'set', 'of', 'attack', 'mitigation', 'rules', 'comprises', 'one', 'or', 'more', 'rules', 'used', 'by', 'the', 'model', 'to', 'identify', 'the', 'key-gate', 'type', 'for', 'each', 'possible', 'design', 'modification', 'solution', 'of', 'the', 'one', 'or', 'more', 'possible', 'design', 'modification', 'solutions', 'and', 'one', 'or', 'more', 'rules', 'used', 'by', 'the', 'model', 'to', 'identify', 'the', 'location', 'where', 'to', 'insert', 'the', 'key-gate', 'type', 'for', 'each', 'possible', 'design', 'modification', 'solution', 'of', 'the', 'one', 'or', 'more', 'possible', 'design', 'modification', 'solutions.']
Sample Train tags:['O', 'B-Electrical Circuit', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-Chemical Compound', 'I-Chemical Compound', 'O', 'O', 'O', 'B-Chemical Compound', 'I-Chemical Compound', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
Sample Valid tags:['O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-Process', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-Process', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-Algorithm', 'I-Algorithm', 'I-Algorithm', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-Biotechnology', 'I-Biotechnology', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-Process', 'I-Process', 'O', 'O']
由于我们正在尝试使数据创建和模型训练管道正常运行,为了简单起见,我只处理一个记录用于训练和验证。
步骤 4:将标签(标签)映射到整数
def mapped_ner_tags(ner_tags: List[List[str]]) -> List[List[int]]:
"""
Convert a list of NER tags to their corresponding integer IDs.
This function takes a list of lists containing string NER tags, creates a unique mapping
of these tags to integer IDs, and then converts all tags to their respective IDs.
Args:
ner_tags (List[List[str]]): A list of lists, where each inner list contains string NER tags.
Returns:
List[List[int]]: A list of lists, where each inner list contains integer IDs
corresponding to the input NER tags.
Example:
>>> ner_tags = [['O', 'B-PER', 'I-PER'], ['O', 'B-ORG']]
>>> mapped_ner_tags(ner_tags)
[[0, 1, 2], [0, 3]]
Note:
The mapping of tags to IDs is created based on the unique tags present in the input.
The order of ID assignment may vary between function calls if the input changes.
"""
labels = list(set([item for sublist in ner_tags for item in sublist]))
id2label = {i: label for i, label in enumerate(labels)}
label2id = {label: id_ for id_, label in id2label.items()}
mapped_ner_tags = [[label2id[label] for label in ner_tag] for ner_tag in ner_tags]
return mapped_ner_tags
def get_labels(ner_tags: List[List[str]]) -> List[str]:
"""
Extract unique labels from a list of NER tag sequences.
This function takes a list of lists containing NER tags and returns a list of unique labels
found across all sequences.
Args:
ner_tags (List[List[str]]): A list of lists, where each inner list contains string NER tags.
Returns:
List[str]: A list of unique NER labels found in the input sequences.
Example:
>>> ner_tags = [['O', 'B-PER', 'I-PER'], ['O', 'B-ORG', 'I-ORG'], ['O', 'B-PER']]
>>> get_labels(ner_tags)
['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG']
Note:
The order of labels in the output list is not guaranteed to be consistent
between function calls, as it depends on the order of iteration over the set.
"""
return list(set([item for sublist in ner_tags for item in sublist]))
步骤 5:Argilla 数据集到 HuggingFace 数据集
现在我们的数据结构已符合 token 分类数据集的要求。我们只需创建一个 Hugging Face 数据集。
train_labels = get_labels(train_ner_tags)
validation_labels = get_labels(validation_ner_tags)
labels = list(set(train_labels + validation_labels))
features = Features({
"tokens": Sequence(Value("string")),
"ner_tags": Sequence(ClassLabel(num_classes=len(labels), names=labels))
})
train_records = [
{
"tokens": token,
"ner_tags": ner_tag,
}
for token, ner_tag in zip(train_tokens, mapped_ner_tags(train_ner_tags))
]
validation_records = [
{
"tokens": token,
"ner_tags": ner_tag,
}
for token, ner_tag in zip(validation_tokens, mapped_ner_tags(validation_ner_tags))
]
span_dataset = DatasetDict(
{
"train": Dataset.from_list(train_records,features=features),
"validation": Dataset.from_list(validation_records,features=features),
}
)
# assertion to verify if train split conforms the dataset structure required for fine-tuning.
assert span_dataset['train'].features['ner_tags'].feature.names is not None
步骤 6:将数据集推送到 Hugginface Hub
!huggingface-cli login
_| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
_| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
_| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|
To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible):
Add token as git credential? (Y/n) n
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful
span_dataset.push_to_hub("bikashpatra/sample_claims_annotated_hf")
Uploading the dataset shards: 0%| | 0/1 [00:00<?, ?it/s]
Creating parquet from Arrow format: 0%| | 0/1 [00:00<?, ?ba/s]
Uploading the dataset shards: 0%| | 0/1 [00:00<?, ?it/s]
Creating parquet from Arrow format: 0%| | 0/1 [00:00<?, ?ba/s]
CommitInfo(commit_url='https://huggingface.co/datasets/bikashpatra/sample_claims_annotated_hf/commit/e9faaa35dda423fcb2bccde9f19cbacd832af80a', commit_message='Upload dataset', commit_description='', oid='e9faaa35dda423fcb2bccde9f19cbacd832af80a', pr_url=None, pr_revision=None, pr_num=None)
5. 使用 AutoTrain 进行模型微调
Huggingface AutoTrain 是一个无需编写任何代码即可训练模型的简单工具。我们可以使用 autotrain 对 token 分类、文本生成、图像分类等各种任务进行微调。为了使用 AutoTrain,我们首先需要在 HF space 中创建一个 AutoTrain 实例。使用 创建空间 链接。对于 space SDK,选择 Docker,然后选择 AutoTrain 作为 Docker 模板。我们需要选择一个硬件来训练我们的模型。请查看屏幕截图以供快速参考
display_image("/content/images/autotrain_screen1.png")
display_image("/content/images/autotrain_screen2.png")
5.1 使用 AutoTrain UI
创建空间后,AutoTrain UI 将允许我们从一系列任务中选择。我们必须在 AutoTrain UI 上配置我们的训练器。
- 我们将选择 Token 分类作为我们的任务。
- 在本教程中,我们将微调
google-bert/bert-base-uncased
。我们可以从列表中选择任何模型。 - 对于数据源,选择
Hugging Face Hub
,它将提供一个文本框,用于填写我们想要用于微调的数据集。我们将使用我们推送到 Huggingface hub 的数据集。我将使用我推送到 huggingface hub 的数据集bikashpatra/claims_annotated_hf
- 输入
train
和validation
分割的键。 - 在列映射下,输入存储 token 和标签的列。在我的数据集中,token 存储在
tokens
列中,标签存储在ner_tags
列中。通过上述 5 个输入,我们可以触发Start Training
,AutoTrain 将负责微调我们数据集上的基础模型。
display_image("/content/images/autotrain_ui.png")
5.2 使用 AutoTrain CLI
# for this cell to work, you will have to store HF_TOKEN as secret in colab notebook.
os.environ['TOKEN'] = userdata.get('HF_TOKEN')
!autotrain token-classification --train \
--username "bikashpatra" \
--token $TOKEN \
--backend "spaces-a10g-small" \
--project-name "claims-token-classification" \
--data-path "bikashpatra/sample_claims_annotated_hf" \
--train-split "train" \
--valid-split "validation" \
--tokens-column "tokens" \
--tags-column "ner_tags" \
--model "distilbert-base-uncased" \
--lr "2e-5" \
--log "tensorboard" \
--epochs "10" \
--weight-decay "0.01" \
--warmup-ratio "0.1" \
--max-seq-length "256" \
--mixed-precision "fp16" \
--push-to-hub
[1mINFO [0m | [32m2024-08-20 06:44:18[0m | [36mautotrain.cli.run_token_classification[0m:[36mrun[0m:[36m179[0m - [1mRunning Token Classification[0m
[33m[1mWARNING [0m | [32m2024-08-20 06:44:18[0m | [36mautotrain.trainers.common[0m:[36m__init__[0m:[36m180[0m - [33m[1mParameters supplied but not used: version, inference, config, func, train, deploy, backend[0m
[1mINFO [0m | [32m2024-08-20 06:44:22[0m | [36mautotrain.cli.run_token_classification[0m:[36mrun[0m:[36m185[0m - [1mJob ID: bikashpatra/autotrain-claims-token-classification[0m
AutoTrain 会自动为我们创建 huggingface space 并触发训练任务。创建的空间链接是 `https://huggingface.co/spaces/$JOBID`,其中 JOBID 是我们从 autotrain cli 命令日志中获得的值。
如果模型训练成功执行,我们的模型将以我们提供给 --project-name
的值提供。在上面的例子中,它是 claims-token-classification
6. 推理
经过所有辛勤工作,我们已经在自定义数据集上训练了我们的模型。我们可以使用我们训练好的模型来预测未注释行的标签。我们将使用 HF Pipelines API。Pipelines 是易于使用的抽象,用于加载模型并对未见过的数据执行推理。在本教程的上下文中,“对未见过文本的推理”意味着预测未注释文本中 token 的标签。
# Classify a sample text
claims_text = """
The FINFET of claim 11 , wherein the conformal gate dielectric comprises a high-κ gate dielectric selected from
the group consisting of: hafnium oxide (HfO 2 ), lanthanum oxide (La 2 O 3 ), and combinations thereof.
"""
classifier = pipeline("token-classification", model="bikashpatra/claims-token-classification",device="cpu")
preds = classifier(claims_text)
# The labels used for fine-tuning the model.
classifier.model.config.id2label
{0: 'B-Chemical Compound',
1: 'I-Biotechnology',
2: 'B-Electrical Circuit',
3: 'B-Process',
4: 'B-Biotechnology',
5: 'O',
6: 'I-Chemical Compound',
7: 'I-Process',
8: 'B-Algorithm',
9: 'I-Algorithm'}
7. 将预测结果推送到 Argilla 数据集
使用 rg.Query
API,我们过滤未注释数据并预测 token。
过滤器 rg.Filter(("response.status","==","pending"))
允许我们创建一个 Argilla 过滤器,我们将其传递给 rg.Query
,以获取 Argilla 数据集中所有未注释的记录。
# Create a filter query to get only `pending` records in argilla dataset.
status_filter = rg.Query(filter=rg.Filter(("response.status", "==", "pending")))
submitted = rg_dataset.records(status_filter).to_list(flatten=True)
claims = random.sample(submitted,k=10) # pick 10 random samples.
spans = classifier(claims[0]['tokens'])
7.1 预测 span 的辅助函数
def predict_spanmarker(pipe:TokenClassificationPipeline,text: str):
"""
Predict span markers for the given text using the provided pipeline.
Args:
pipe (TokenClassificationPipeline): A pipeline object for token classification.
text (str): The input text for which span markers are to be predicted.
Returns:
List[Dict[str, Union[int, str]]]: A list of dictionaries containing the predicted span markers.
Each dictionary should have 'start', 'end', and 'label' keys.
"""
markers = pipe(text)
spans = [
{"label": marker["entity"][2:], "start": marker["start"], "end": marker["end"]}
for marker in markers if marker["entity"] != "O"
]
return spans
updated_data=[
{
"span_label": predict_spanmarker(pipe=classifier, text=sample['tokens']),
"id": sample["id"],
}
for sample in claims
]
# print a few predictions
updated_data[0]['span_label'][:2]
[{'label': 'Chemical Compound', 'start': 0, 'end': 3},
{'label': 'Process', 'start': 4, 'end': 10}]
7.2 将记录插入 Argilla 数据集。
rg_dataset.records.log(records=updated_data)
DatasetRecords: The provided batch size 256 was normalized. Using value 10.
Sending records...: 100%|██████████| 1/1 [00:00<00:00, 1.15batch/s]
DatasetRecords(Dataset(id=UUID('a187cdad-175e-4d87-989f-a529b9999bde') inserted_at=datetime.datetime(2024, 7, 28, 7, 23, 59, 902685) updated_at=datetime.datetime(2024, 7, 28, 7, 35, 55, 80617) name='claim_tokens' status='ready' guidelines='Classify individual tokens according to the specified categories, ensuring that any overlapping or nested entities are accurately captured.' allow_extra_metadata=False distribution=None workspace_id=UUID('fd4fc24c-fc1f-4ffe-af41-d569432d6b50') last_activity_at=datetime.datetime(2024, 7, 28, 7, 35, 55, 80181)))
我们在此处更新的记录存储为 suggestions
,而不是 responses
。在本教程的上下文中,当注释者保存注释时会创建响应。建议是模型预测的标签。因此,我们在此处更新的记录将具有 response.status
作为 pending
,而不是 submitted
。这将允许我们/注释者检查预测的标签并接受或拒绝模型预测。
如果我们要接受模型对文本中 token 的预测注释,我们可以将 [suggestions
] 保存为 [responses
],否则我们需要添加/删除/编辑应用于 token 的标签。
8. 结论
在本全面的教程中,我们探索了数据注释和模型微调的完整工作流程。我们首先在 Hugging Face Spaces 上设置了一个 Argilla 实例,提供了一个强大的数据管理平台。然后,我们配置并在 Argilla 实例中创建了一个数据集,利用其用户友好的界面手动注释了部分记录。
我们继续将高质量的注释数据导出到 Hugging Face 数据集,弥合了注释和模型训练之间的鸿沟。然后,我们展示了迁移学习的强大功能,使用 Hugging Face 的 AutoTrain(一个简化模型训练复杂性的工具)对这个精心策划的数据集上的 distilbert-base-uncased
模型进行了微调。
工作流程最终形成了一个闭环,我们将微调后的模型应用于 Argilla 数据集中剩余的未标记记录进行注释,展示了机器学习如何加速注释过程。本教程为实现迭代注释和微调管道奠定了坚实的基础,同时说明了人类专业知识和机器学习能力之间的协同作用。
这种迭代方法允许持续改进,使其成为高效、有效地处理各种自然语言处理任务的宝贵工具。
9. 致谢
我要向以下为本笔记本做出贡献的个人表示衷心的感谢:
- David Berenstein 提供了宝贵的见解和指导。
- Sara Han 回答了我 Discord 上频繁的疑问。
没有他们的支持和专业知识,这项工作是不可能完成的。
此外,将 https://github.com/bikash119/argilla/blob/argilla_with_autotrain/argilla/docs/community/token_classification_tutorial.ipynb
中的 github 替换为 nbsanity 即可看到本笔记本的更漂亮版本。感谢 Hamel Hussain 创建了这个笔记本渲染工具。