使用 ModernBERT 刷新零样本分类
文本分类是机器学习领域的一项基础任务,拥有悠久的研究历史和深远的实际应用价值。此外,它也是许多实际项目中最重要的组成部分之一,从搜索引擎到生物医学研究无所不包。文本分类方法被用于科学文章分类、用户工单分类、社交媒体帖子情感分析、金融研究中的公司分类等。如果我们将任务推广到序列分类,那么该领域的应用数量和影响力将更大,从 DNA 序列分类到聊天机器人系统中用于保证高质量最新输出的最常用方式 RAG 管道,都属于此类。
自回归语言建模的最新进展为包括文本分类在内的许多零样本分类任务开辟了新领域。虽然这些模型提供了令人印象深刻的多功能性,但它们在严格遵守指令方面常常力不从心,并且在训练和推理方面计算效率低下。请查看我们旨在通过生成式语言模型实现零样本分类的项目:Knowledgator Unlimited Classifier
交叉编码器作为自然语言推理 (NLI) 模型是另一种流行的零样本分类器和检索增强生成 (RAG) 管道中常用的方法。该方法通过将待分类序列作为 NLI 前提,并从每个候选标签构建假设来工作。总体而言,由于其成对处理方法,这种方法在处理大量类别时面临效率挑战。此外,它理解交叉标签信息的能力有限,这会影响预测质量,尤其是在复杂场景中。
自 Science Word2Vec 工作以来,基于嵌入的方法被认为是文本分类的潜在方法之一,尤其是在零样本设置中。随着句子编码器对句子和文本语义理解的提升,使用句子嵌入进行文本分类的想法也应运而生。Sentence Transformers 的出现进一步提高了嵌入的质量,甚至无需微调即可将其用于分类任务。SetFit——一项基于句子转换器的工作,即使在每个标签只有少量示例的情况下也能实现良好的性能。尽管它们在许多语义任务中效率高且性能良好,但基于嵌入的方法在涉及逻辑和语义约束的复杂场景中常常表现不佳。
我们提出了一种新颖的文本分类方法,它基于 GLiNER 架构,并对其进行了专门的改编以用于序列分类任务。我们的方法旨在平衡更复杂模型的准确性和基于嵌入方法的效率,同时保持良好的零样本和少样本能力。
GLiClass 架构
我们的架构引入了一种新颖的序列分类方法,它能够在标签和输入文本之间实现丰富的交互,同时保持计算效率。该实现包含几个协同工作的关键阶段,以实现卓越的分类性能。
输入处理和标签整合
该过程从标签整合机制开始。我们为每个类别标签预置一个特殊标记 <
上下文表示学习
分词之后,合并后的输入 ID 通过双向 Transformer 架构(如 BERT 或 DeBERTa)进行处理。此阶段至关重要,因为它支持三种不同类型的上下文理解:
- 标签间交互:标签可以相互共享信息,使模型能够理解标签关系和层次结构。
- 文本到标签交互:输入文本可以直接影响标签表示。
- 标签到文本交互:标签信息可以指导文本的解释。
这种多向信息流比传统的交叉编码器架构具有显著优势,后者通常将交互限制在文本-标签对之间,并忽略了有价值的标签间关系。
表示池化
获得上下文表示后,我们对标签和文本采用单独的池化机制,从 Transformer 输出中提取基本信息。我们的实现支持多种池化策略:
- 第一个标记池化:利用初始标记的表示。
- 均值池化:对所有标记取平均值。
- 注意力加权池化:应用学习到的注意力权重。
- 自定义池化策略:根据特定的分类要求定制。
池化策略的选择可以根据分类任务的具体要求和输入数据的性质进行优化。
评分机制
最后阶段涉及计算池化表示之间的兼容性分数。我们通过一个灵活的评分框架来实现这一点,该框架可以适应各种方法:
** 简单点积评分:对于许多应用来说高效且有效 ** 神经网络评分:用于挑战性场景的更复杂的评分函数 ** 特定任务评分模块:根据特定分类要求定制
这种模块化的评分方法允许架构适应不同的分类场景,同时保持计算效率。
如何使用模型
我们在 Hugging Face 上开源了我们的模型。Modern GLiClass Collection
要使用它们,首先安装 gliclass 软件包
pip install gliclass
然后你需要初始化一个模型和一个管道
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer
model = GLiClassModel.from_pretrained("knowledgator/gliclass-modern-base-v2.0-init")
tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-modern-base-v2.0-init")
pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')
以下是如何执行推理
text = "One day I will see the world!"
labels = ["travel", "dreams", "sport", "science", "politics"]
results = pipeline(text, labels, threshold=0.5)[0]
for result in results:
print(result["label"], "=>", result["score"])
如何微调
首先,您需要准备以下格式的训练数据:
[
{"text": "Some text here!",
"all_labels": ["sport", "science", "business", …],
"true_labels": ["other"]}, …
]
下面是所需的导入要求:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
from datasets import load_dataset, Dataset, DatasetDict
from sklearn.metrics import classification_report, f1_score, precision_recall_fscore_support, accuracy_score
import numpy as np
import random
from transformers import AutoTokenizer
import torch
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from gliclass.data_processing import GLiClassDataset, DataCollatorWithPadding
from gliclass.training import TrainingArguments, Trainer
然后,我们初始化模型和分词器
device = torch.device('cuda:0') if torch.cuda.is_available else torch.device('cpu')
model_name = 'knowledgator/gliclass-base-v1.0'
model = GLiClassModel.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
After, that we specify training arguments:
max_length = 1024
problem_type = "multi_label_classification"
architecture_type = model.config.architecture_type
prompt_first = model.config.prompt_first
training_args = TrainingArguments(
output_dir='models/test',
learning_rate=1e-5,
weight_decay=0.01,
others_lr=1e-5,
others_weight_decay=0.01,
lr_scheduler_type='linear',
warmup_ratio=0.0,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
num_train_epochs=8,
evaluation_strategy="epoch",
save_steps = 1000,
save_total_limit=10,
dataloader_num_workers=8,
logging_steps=10,
use_cpu = False,
report_to="none",
fp16=False,
)
当您以正确格式准备好数据集后,我们需要初始化 GLiClass 数据集和数据收集器。
train_dataset = GLiClassDataset(train_data, tokenizer, max_length, problem_type, architecture_type, prompt_first)
test_dataset = GLiClassDataset(train_data[:int(len(train_data)*0.1)], tokenizer, max_length, problem_type, architecture_type, prompt_first)
data_collator = DataCollatorWithPadding(device=device)
一切准备就绪后,我们就可以开始训练了。
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
)
trainer.train()
更多示例请参阅仓库:https://github.com/Knowledgator/GLiClass/blob/main/finetuning.ipynb
主要应用
GLiClass 在广泛的自然语言处理任务中表现出卓越的多功能性,这使其在研究和实际应用中都具有特别的价值。
多类分类
该架构可在一个处理运行中高效处理多达 100 个不同类别的大规模分类任务。此功能对于需要多个详细类别的应用程序尤其有价值,例如文档分类、产品分类和内容标记系统。
主题分类
GLiClass 擅长识别和分类文本的主题,非常适合:
- 学术论文分类
- 新闻文章分类
- 内容推荐系统
- 研究文档组织
情感分析
该架构能有效捕捉细微的情感和基于观点的语义内容,支持:
- 社交媒体情感追踪
- 客户反馈分析
- 产品评论分类
- 品牌认知度监测
事件分类
GLiClass 在识别和分类文本中的事件方面表现出强大的能力,支持:
- 新闻事件分类
- 社交媒体事件检测
- 历史事件分类
- 时间轴分析和组织
基于提示的约束分类
该系统提供灵活的基于提示的分类,并带有自定义约束,从而实现:
- 引导式分类任务
- 上下文感知分类
- 自定义分类规则
- 动态类别适应
自然语言推理
GLiClass 支持对文本关系进行复杂的推理,从而促进:
- 文本蕴含检测
- 矛盾识别
- 语义相似度评估
- 逻辑关系分析
检索增强生成 (RAG)
该架构良好的泛化能力以及对自然语言推理任务的支持,使其成为 RAG 管道中重新排序的良好选择。此外,GLiClass 的效率使其更具竞争力,尤其是在与交叉编码器相比时。
这一全面的应用范围使 GLiClass 成为应对现代自然语言处理挑战的多功能工具,在各种分类任务中提供灵活性和精确性。
基准测试结果
我们发布了基于 ModernBERT 的新 GLiClass 模型,与旧模型(如 DeBERTa)相比,它提供了更长的上下文长度支持(高达 8k 标记)和更快的推理速度,从而开辟了新的可能性。我们使用我们的架构和旧版架构对 GLiClass 模型进行了基准测试。
在下面,您可以看到在几个文本分类数据集上的 F1 分数。所有测试模型均未在这些数据集上进行微调,并以零样本设置进行测试。
模型 | IMDB | AG_NEWS | 情感 |
---|---|---|---|
gliclass-modern-large-v2.0-init (399 M) | 0.9137 | 0.7357 | 0.4140 |
gliclass-modern-base-v2.0-init (151 M) | 0.8264 | 0.6637 | 0.2985 |
gliclass-large-v1.0 (438 M) | 0.9404 | 0.7516 | 0.4874 |
gliclass-base-v1.0 (186 M) | 0.8650 | 0.6837 | 0.4749 |
gliclass-small-v1.0 (144 M) | 0.8650 | 0.6805 | 0.4664 |
Bart-large-mnli (407 M) | 0.89 | 0.6887 | 0.3765 |
Deberta-base-v3 (184 M) | 0.85 | 0.6455 | 0.5095 |
Comprehendo (184M) | 0.90 | 0.7982 | 0.5660 |
SetFit BAAI/bge-small-en-v1.5 (33.4M) | 0.86 | 0.5636 | 0.5754 |
在下面,您可以找到 ModernBERT GLiClass 与其他 GLiClass 模型的更全面比较
数据集 | gliclass-base-v1.0-init | gliclass-large-v1.0-init | gliclass-modern-base-v2.0-init | gliclass-modern-large-v2.0-init |
---|---|---|---|---|
CR | 0.8672 | 0.8024 | 0.9041 | 0.8980 |
sst2 | 0.8342 | 0.8734 | 0.9011 | 0.9434 |
sst5 | 0.2048 | 0.1638 | 0.1972 | 0.1123 |
20_新闻组 | 0.2317 | 0.4151 | 0.2448 | 0.2792 |
垃圾邮件 | 0.5963 | 0.5407 | 0.5074 | 0.6364 |
金融短语库 | 0.3594 | 0.3705 | 0.2537 | 0.2562 |
imdb | 0.8772 | 0.8836 | 0.8255 | 0.9137 |
ag_新闻 | 0.5614 | 0.7069 | 0.6050 | 0.6933 |
情感 | 0.2865 | 0.3840 | 0.2474 | 0.3746 |
cap_sotu | 0.3966 | 0.4353 | 0.2929 | 0.2919 |
烂番茄 | 0.6626 | 0.7933 | 0.6630 | 0.5928 |
平均值 | 0.5344 | 0.5790 | 0.5129 | 0.5447 |
我们研究了如果对模型进行微调,在每个标签的少量示例上,性能如何增长。此外,我们测试了一种简单的方法,即不提供真实的文本,而是提供给定文本主题的通用简短描述,我们称之为弱监督。令人惊讶的是,对于某些数据集,如“情感”,它显著提高了性能。
模型 | 示例数量 | sst5 | ag_新闻 | 情感 | 平均值 |
---|---|---|---|---|---|
gliclass-modern-large-v2.0-init | 0 | 0.1123 | 0.6933 | 0.3746 | 0.3934 |
gliclass-modern-large-v2.0-init | 8 | 0.5098 | 0.8339 | 0.5010 | 0.6149 |
gliclass-modern-large-v2.0-init | 弱监督 | 0.0951 | 0.6478 | 0.4520 | 0.3983 |
gliclass-modern-base-v2.0-init | 0 | 0.1972 | 0.6050 | 0.2474 | 0.3499 |
gliclass-modern-base-v2.0-init | 8 | 0.3604 | 0.7481 | 0.4420 | 0.5168 |
gliclass-modern-base-v2.0-init | 弱监督 | 0.1599 | 0.5713 | 0.3216 | 0.3509 |
gliclass-large-v1.0-init | 0 | 0.1639 | 0.7069 | 0.3840 | 0.4183 |
gliclass-large-v1.0-init | 8 | 0.4226 | 0.8415 | 0.4886 | 0.5842 |
gliclass-large-v1.0-init | 弱监督 | 0.1689 | 0.7051 | 0.4586 | 0.4442 |
gliclass-base-v1.0-init | 0 | 0.2048 | 0.5614 | 0.2865 | 0.3509 |
gliclass-base-v1.0-init | 8 | 0.2007 | 0.8359 | 0.4856 | 0.5074 |
gliclass-base-v1.0-init | 弱监督 | 0.0681 | 0.6627 | 0.3066 | 0.3458 |
结论
GLiClass 代表了文本分类领域的一项重大进展,它提供了一个强大而高效的解决方案,弥合了复杂基于 Transformer 的模型的准确性和基于嵌入的方法的简单性之间的鸿沟。通过利用一种新颖的架构,促进输入文本和标签之间的丰富交互,GLiClass 在零样本和少样本分类任务中实现了卓越的性能,同时保持了计算效率,即使在大型标签集的情况下也是如此。它捕捉交叉标签依赖关系、适应不同分类场景以及与现有 NLP 管道无缝集成的能力,使其成为各种应用程序的多功能工具,从情感分析和主题分类到检索增强生成和自然语言推理。