开源 AI 食谱文档

使用 Cleanlab 检测文本数据集中的问题

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

Open In Colab

使用 Cleanlab 检测文本数据集中的问题

作者:Aravind Putrevu

在这个 5 分钟的快速入门教程中,我们将使用 Cleanlab 来检测一个意图分类数据集中的各种问题。该数据集由一家网上银行的 (文本) 客户服务请求组成。我们考虑了 Banking77-OOS 数据集 的一个子集,其中包含 1000 个客户服务请求,这些请求根据其意图被分为 10 个类别 (您可以在任何文本分类数据集上运行相同的代码)。Cleanlab 能够自动识别我们数据集中的不良样本,包括错误标注的数据、超出范围的样本 (离群点) 或其他模糊的样本。在深入进行数据建模之前,请考虑过滤或修正这些不良样本!

本教程内容概览:

  • 使用预训练的 Transformer 模型从客户服务请求中提取文本嵌入

  • 在文本嵌入上训练一个简单的逻辑回归模型,以计算样本外 (out-of-sample) 预测概率

  • 使用这些预测和嵌入运行 Cleanlab 的 `Datalab` 审计,以识别数据集中的问题,如:标签问题、离群点和近似重复项。

快速入门

已经有了在现有标签集上训练的模型得到的 (样本外) `pred_probs` (预测概率)?或许你还有一些数值 `features` (特征)?运行下面的代码来查找数据集中任何潜在的标签错误。

注意: 如果在 Colab 上运行,可能需要使用 GPU (选择:运行时 > 更改运行时类型 > 硬件加速器 > GPU)

from cleanlab import Datalab

lab = Datalab(data=your_dataset, label_name="column_name_of_labels")
lab.find_issues(pred_probs=your_pred_probs, features=your_features)

lab.report()
lab.get_issues()

安装所需的依赖项

您可以使用 `pip` 安装本教程所需的所有软件包,如下所示

!pip install -U scikit-learn sentence-transformers datasets
!pip install -U "cleanlab[datalab]"
import re
import string
import pandas as pd
from sklearn.metrics import accuracy_score, log_loss
from sklearn.model_selection import cross_val_predict
from sklearn.linear_model import LogisticRegression
from sentence_transformers import SentenceTransformer

from cleanlab import Datalab
import random
import numpy as np

pd.set_option("display.max_colwidth", None)

SEED = 123456  # for reproducibility
np.random.seed(SEED)
random.seed(SEED)

加载并格式化文本数据集

from datasets import load_dataset

dataset = load_dataset("PolyAI/banking77", split="train")
data = pd.DataFrame(dataset[:1000])
data.head()
>>> raw_texts, labels = data["text"].values, data["label"].values
>>> num_classes = len(set(labels))

>>> print(f"This dataset has {num_classes} classes.")
>>> print(f"Classes: {set(labels)}")
This dataset has 7 classes.
Classes: {32, 34, 36, 11, 13, 46, 17}

让我们查看数据集中的第 i 个示例

>>> i = 1  # change this to view other examples from the dataset
>>> print(f"Example Label: {labels[i]}")
>>> print(f"Example Text: {raw_texts[i]}")
Example Label: 11
Example Text: What can I do if my card still hasn't arrived after 2 weeks?

数据存储为两个 numpy 数组

  1. `raw_texts` 以文本格式存储客户服务请求的话语
  2. `labels` 存储每个示例的意图类别 (标签)
使用您自己的数据 (BYOD)?

您可以轻松地将上面的数据替换为您自己的文本数据集,并继续本教程的其余部分。

接下来,我们将文本字符串转换为更适合作为我们机器学习模型输入的向量。

我们将使用来自预训练 Transformer 模型的数值表示作为我们文本的嵌入。 Sentence Transformers 库提供了计算文本数据嵌入的简单方法。这里,我们加载预训练的 `electra-small-discriminator` 模型,然后通过网络运行我们的数据来提取每个示例的向量嵌入。

transformer = SentenceTransformer("google/electra-small-discriminator")
text_embeddings = transformer.encode(raw_texts)

我们后续的机器学习模型将直接操作 `text_embeddings` 的元素,以便对客户服务请求进行分类。

定义分类模型并计算样本外预测概率

利用预训练网络进行特定分类任务的典型方法是添加一个线性输出层,并在新数据上微调网络参数。然而,这可能计算量很大。或者,我们可以冻结网络的预训练权重,只训练输出层,而无需依赖 GPU。这里我们通过在提取的嵌入之上拟合一个 scikit-learn 线性模型来方便地实现这一点。

为了识别标签问题,cleanlab 需要模型对每个数据点进行概率预测。然而,对于模型之前训练过的数据点,这些预测会 *过拟合* (因此不可靠)。cleanlab 仅用于处理 **样本外** (out-of-sample) 的预测类别概率,即在模型训练期间未见过的数据点。

这里我们使用逻辑回归模型和交叉验证,为数据集中的每个示例获取样本外预测类别概率。请确保 `pred_probs` 的列相对于类别顺序正确排列,对于 Datalab 来说,这是按类别名称的字典序排序的。

model = LogisticRegression(max_iter=400)

pred_probs = cross_val_predict(model, text_embeddings, labels, method="predict_proba")

使用 Cleanlab 发现数据集中的问题

给定特征嵌入和从任何模型获得的 (样本外) 预测类别概率,cleanlab 可以快速帮助您识别数据集中的低质量样本。

在这里,我们使用 Cleanlab 的 `Datalab` 来发现数据中的问题。Datalab 提供了多种加载数据的方式;我们将简单地将训练特征和带噪声的标签包装在一个字典中。

data_dict = {"texts": raw_texts, "labels": labels}

要审计您的数据,只需调用 `find_issues()`。我们传入上面获得的预测概率和特征嵌入,但您不一定需要提供所有这些信息,这取决于您感兴趣的问题类型。您提供的输入越多,`Datalab` 就能在您的数据中检测到更多类型的问题。使用更好的模型来生成这些输入将确保 cleanlab 更准确地评估问题。

lab = Datalab(data_dict, label_name="labels")
lab.find_issues(pred_probs=pred_probs, features=text_embeddings)

输出会是这样的

Finding null issues ...
Finding label issues ...
Finding outlier issues ...
Fitting OOD estimator based on provided features ...
Finding near_duplicate issues ...
Finding non_iid issues ...
Finding class_imbalance issues ...
Finding underperforming_group issues ...

Audit complete. 62 issues found in the dataset.

审计完成后,使用 `report` 方法查看结果

>>> lab.report()
Here is a summary of the different kinds of issues found in the data:

    issue_type  num_issues
       outlier          37
near_duplicate          14
         label          10
       non_iid           1

Dataset Information: num_examples: 1000, num_classes: 7


---------------------- outlier issues ----------------------

About this issue:
	Examples that are very different from the rest of the dataset 
    (i.e. potentially out-of-distribution or rare/anomalous instances).
    

Number of examples with this issue: 37
Overall dataset quality in terms of this issue: 0.3671

Examples representing most severe instances of this issue:
     is_outlier_issue  outlier_score
791              True       0.024866
601              True       0.031162
863              True       0.060738
355              True       0.064199
157              True       0.065075


------------------ near_duplicate issues -------------------

About this issue:
	A (near) duplicate issue refers to two or more examples in
    a dataset that are extremely similar to each other, relative
    to the rest of the dataset.  The examples flagged with this issue
    may be exactly duplicated, or lie atypically close together when
    represented as vectors (i.e. feature embeddings).
    

Number of examples with this issue: 14
Overall dataset quality in terms of this issue: 0.5961

Examples representing most severe instances of this issue:
     is_near_duplicate_issue  near_duplicate_score near_duplicate_sets  distance_to_nearest_neighbor
459                     True              0.009544               [429]                      0.000566
429                     True              0.009544               [459]                      0.000566
501                     True              0.046044          [412, 517]                      0.002781
412                     True              0.046044               [501]                      0.002781
698                     True              0.054626               [607]                      0.003314


----------------------- label issues -----------------------

About this issue:
	Examples whose given label is estimated to be potentially incorrect
    (e.g. due to annotation error) are flagged as having label issues.
    

Number of examples with this issue: 10
Overall dataset quality in terms of this issue: 0.9930

Examples representing most severe instances of this issue:
     is_label_issue  label_score  given_label  predicted_label
379           False     0.025486           32               11
100           False     0.032102           11               36
300           False     0.037742           32               46
485            True     0.057666           17               34
159            True     0.059408           13               11


---------------------- non_iid issues ----------------------

About this issue:
	Whether the dataset exhibits statistically significant
    violations of the IID assumption like:
    changepoints or shift, drift, autocorrelation, etc.
    The specific violation considered is whether the
    examples are ordered such that almost adjacent examples
    tend to have more similar feature values.
    

Number of examples with this issue: 1
Overall dataset quality in terms of this issue: 0.0000

Examples representing most severe instances of this issue:
     is_non_iid_issue  non_iid_score
988              True       0.563774
975             False       0.570179
997             False       0.571891
967             False       0.572357
956             False       0.577413

Additional Information: 
p-value: 0.0

标签问题

报告显示,cleanlab 在我们的数据集中识别出了许多标签问题。我们可以使用 `get_issues` 方法查看哪些样本被标记为可能错误标注,以及每个样本的标签质量得分,并将 `label` 指定为参数以专注于数据中的标签问题。

label_issues = lab.get_issues("label")
label_issues.head()
is_label_issue (是否是标签问题) label_score (标签得分) given_label (给定标签) predicted_label (预测标签)
0 否 (False) 0.903926 11 11
1 否 (False) 0.860544 11 11
2 否 (False) 0.658309 11 11
3 否 (False) 0.697085 11 11
4 否 (False) 0.434934 11 11

此方法返回一个包含每个样本标签质量得分的数据帧。这些数值分数介于 0 和 1 之间,分数越低表示样本越有可能被错误标注。该数据帧还包含一个布尔列,指明每个样本是否被识别为存在标签问题 (表明它可能被错误标注)。

我们可以获取被标记为有标签问题的样本子集,并按标签质量得分排序,以找出数据集中最可能被错误标注的 5 个样本的索引。

>>> identified_label_issues = label_issues[label_issues["is_label_issue"] == True]
>>> lowest_quality_labels = label_issues["label_score"].argsort()[:5].to_numpy()

>>> print(
...     f"cleanlab found {len(identified_label_issues)} potential label errors in the dataset.\n"
...     f"Here are indices of the top 5 most likely errors: \n {lowest_quality_labels}"
... )
cleanlab found 10 potential label errors in the dataset.
Here are indices of the top 5 most likely errors: 
 [379 100 300 485 159]

让我们回顾一些最可能的标签错误。

在这里,我们展示了被识别为数据集中最可能的 5 个标签错误的样本,以及它们给定的 (原始) 标签和 cleanlab 建议的替代标签。

data_with_suggested_labels = pd.DataFrame(
    {"text": raw_texts, "given_label": labels, "suggested_label": label_issues["predicted_label"]}
)
data_with_suggested_labels.iloc[lowest_quality_labels]

上述命令的输出如下

text (文本) given_label (给定标签) suggested_label (建议标签)
379 我计划进行的转账所用的汇率有特定的来源吗? 32 11
100 你能分享卡的追踪号码吗? 11 36
300 如果我需要兑现国外转账,该如何操作? 32 46
485 我的外汇兑换收费是否比应收的要多? 17 34
159 有没有办法在应用程序里看到我的卡? 13 11

这些是 cleanlab 在此数据中识别出的非常明显的标签错误!请注意,`given_label` (给定标签) 并未正确反映这些请求的意图,制作此数据集的人犯了很多错误,在建模数据之前解决这些错误非常重要。

离群点问题

根据报告,我们的数据集包含一些离群点。我们可以通过 `get_issues` 查看哪些样本是离群点 (以及一个量化每个样本典型程度的数值质量得分)。我们通过 cleanlab 的离群点质量得分对结果数据帧进行排序,以查看数据集中最严重的离群点。

outlier_issues = lab.get_issues("outlier")
outlier_issues.sort_values("outlier_score").head()

输出会是这样的

is_outlier_issue (是否是离群点问题) outlier_score (离群点得分)
791 True 0.024866
601 True 0.031162
863 True 0.060738
355 True 0.064199
157 True 0.065075
lowest_quality_outliers = outlier_issues["outlier_score"].argsort()[:5]

data.iloc[lowest_quality_outliers]

质量最低的离群点的示例输出如下

索引 text (文本) label (标签)
791 提款待定是什么意思? 46
601 交易中有 1 美元的费用。 34
863 我的 atm 取款仍然待定 46
355 解释银行间汇率 32
157 丢失的卡找到了,想把它放回应用程序中 13

我们看到,cleanlab 已经识别出此数据集中的条目似乎不是正常的客户请求。此数据集中的离群点似乎是超出范围的客户请求以及其他无意义的文本,这对于意图分类没有意义。请仔细考虑这类离群点是否会对您的数据建模产生不利影响,如果是,请考虑将它们从数据集中移除。

近似重复问题

根据报告,我们的数据集包含一些近似重复的样本集。我们可以通过 `get_issues` 查看哪些样本是 (近似) 重复的 (以及一个量化每个样本与其数据集中最近邻居差异程度的数值质量得分)。我们通过 cleanlab 的近似重复质量得分对结果数据帧进行排序,以查看数据集中最接近重复的文本样本。

duplicate_issues = lab.get_issues("near_duplicate")
duplicate_issues.sort_values("near_duplicate_score").head()

以上结果显示了 cleanlab 认为近似重复的样本 (其中 `is_near_duplicate_issue == True` 的行)。在这里,我们看到样本 459 和 429 是近似重复的,样本 501 和 412 也是如此。

让我们查看这些样本,看看它们有多相似。

data.iloc[[459, 429]]

示例输出

索引 text (文本) label (标签)
459 我在国外买东西,但用了错误的汇率。 17
429 我在海外买东西,但用了错误的汇率。 17
data.iloc[[501, 412]]

示例输出

索引 text (文本) label (标签)
501 你们使用的汇率真的很差。这不可能是官方的银行间汇率。 17
412 你们使用的汇率很差。这不可能是官方的银行间汇率。 17

我们看到这两组请求确实非常相似!在数据集中包含近似重复项可能会对模型产生意想不到的影响,并且要警惕将它们分散在训练集和测试集中。从 常见问题解答 中了解更多关于处理数据集中近似重复项的信息。

非独立同分布 (Non-IID) 问题 (数据漂移)

根据报告,我们的数据集似乎不是独立同分布 (IID) 的。数据集的整体非 IID 得分 (如下所示) 对应于一个统计检验的 `p-value`,该检验用于判断数据集中样本的顺序是否与其特征值之间的相似性相关。较低的 `p-value` 强烈表明数据集违反了 IID 假设,这是从数据集中产生的结论 (模型) 能够泛化到更大总体的关键假设。

p_value = lab.get_info("non_iid")["p-value"]
p_value

在这里,我们的数据集被标记为非 IID,因为在原始数据中,行恰好按类别标签排序。如果我们记得在模型训练和数据分割之前对行进行洗牌,这可能是良性的。但如果您不知道为什么您的数据被标记为非 IID,那么您应该担心潜在的数据漂移或数据点之间的意外交互 (它们的值可能在统计上不是独立的)。仔细考虑未来的测试数据可能是什么样子 (以及您的数据是否代表您所关心的总体)。在非 IID 测试运行之前,您不应该对数据进行洗牌 (这会使其结论无效)。

如上所示,cleanlab 可以自动列出数据集中最可能的问题,帮助您更好地整理数据集以进行后续建模。有了这份问题清单,您可以决定是修复这些标签问题,还是从数据集中删除无意义或重复的样本,以获得更高质量的数据集来训练您的下一个机器学习模型。cleanlab 的问题检测可以与您最初训练的 *任何* 类型模型的输出一起运行。

Cleanlab 开源项目

Cleanlab 是一个标准的以数据为中心的 AI 软件包,旨在解决现实世界中混乱数据的数据质量问题。

请考虑给 Cleanlab Github 代码仓库一个 Star,我们欢迎对该项目做出贡献

< > 在 GitHub 上更新