音频课程文档

ASR 评估指标

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

ASR 评估指标

如果您熟悉 NLP 中的 Levenshtein 距离,那么评估语音识别系统的指标将很熟悉!如果您不熟悉,请不要担心,我们将从头到尾进行解释,以确保您了解不同的指标并理解它们的含义。

在评估语音识别系统时,我们将系统的预测与目标文本转录进行比较,并标注存在的任何错误。我们将这些错误分为以下三类:

  1. 替换 (S):预测中转录了**错误的词**(“sit” 而不是 “sat”)
  2. 插入 (I):预测中添加了**多余的词**
  3. 删除 (D):预测中**删除了一个词**

这些错误类别对于所有语音识别指标都是相同的。不同之处在于我们计算这些错误的级别:我们可以在*词级别*或*字符级别*计算它们。

我们将为每个指标定义使用一个运行示例。这里,我们有一个*地面真实*或*参考*文本序列

reference = "the cat sat on the mat"

以及我们正在评估的语音识别系统的预测序列

prediction = "the cat sit on the"

我们可以看到预测非常接近,但有些词不太正确。我们将针对三种最流行的语音识别指标评估此预测与参考的对比,看看我们为每种指标获得什么样的数字。

词错误率

词错误率 (WER) 指标是语音识别的“事实标准”指标。它计算*词级别*的替换、插入和删除。这意味着错误是逐词标注的。以我们的例子为例:

参考 the cat sat on the mat
预测 the cat sit on the
标签 S D

这里,我们有

  • 1 个替换(“sit” 而不是 “sat”)
  • 0 个插入
  • 1 个删除(缺少“mat”)

这总共产生了 2 个错误。为了得到错误率,我们将错误数量除以参考中的总词数 (N),本例中为 6。WER=S+I+DN=1+0+16=0.333 \begin{aligned} WER &= \frac{S + I + D}{N} \\ &= \frac{1 + 0 + 1}{6} \\ &= 0.333 \end{aligned}

好的!所以我们得到了 0.333 的 WER,或者说 33.3%。请注意,“sit”这个词只有一个字符是错误的,但整个词都被标记为不正确。这是 WER 的一个显著特征:拼写错误会受到严厉惩罚,无论它们多么微小。

WER 的定义是*越低越好*:较低的 WER 意味着我们的预测中错误较少,因此一个完美的语音识别系统将具有零 WER(无错误)。

让我们看看如何使用 🤗 Evaluate 计算 WER。我们需要两个包来计算我们的 WER 指标:🤗 Evaluate 用于 API 接口,JIWER 用于执行计算的繁重工作。

pip install --upgrade evaluate jiwer

太棒了!我们现在可以加载 WER 指标并计算我们示例的数字。

from evaluate import load

wer_metric = load("wer")

wer = wer_metric.compute(references=[reference], predictions=[prediction])

print(wer)

打印输出

0.3333333333333333

0.33,或 33.3%,正如预期!我们现在知道了 WER 计算的内部原理。

现在,有些事情很令人困惑……您认为 WER 的上限是多少?您会期望它是 1 或 100% 对吗?不!由于 WER 是错误与单词数 (N) 的比率,因此 WER 没有上限!让我们举一个例子,我们预测了 10 个单词,而目标只有 2 个单词。如果我们的所有预测都错了(10 个错误),我们将得到 10 / 2 = 5 的 WER,即 500%!如果您训练 ASR 系统并看到 WER 超过 100%,请记住这一点。不过如果您看到这种情况,很可能出了问题…… 😅

词准确率

我们可以将 WER 反过来,得到一个*越高越好*的指标。我们可以衡量系统的*词准确率 (WAcc)*,而不是衡量词错误率。WAcc=1WER \begin{equation} WAcc = 1 - WER \nonumber \end{equation}

WAcc 也以词级别进行测量,它只是 WER 的一种重新表达,将其作为准确率指标而不是错误指标。WAcc 在语音文献中很少被引用——我们通常以词错误来考虑我们的系统预测,因此更喜欢与这些错误类型标注相关的错误率指标。

字符错误率

将“sit”这个词完全标记为错误似乎有点不公平,因为实际上只有一个字母是错误的。那是因为我们是在词级别评估系统,因此逐词标注错误。*字符错误率 (CER)* 在*字符级别*评估系统。这意味着我们将单词分解成单独的字符,并逐个字符地标注错误。

参考 t h e c a t s a t o n t h e m a t
预测 t h e c a t s i t o n t h e
标签 S D D D D

我们现在可以看到,对于单词“sit”,字母“s”和“t”被标记为正确。只有“i”被标记为替换错误 (S)。因此,我们奖励系统部分正确的预测🤝。

在我们的示例中,我们有 1 个字符替换,0 个插入,和 4 个删除。总共有 22 个字符。因此,我们的 CER 是CER=S+I+DN=1+0+422=0.227 \begin{aligned} CER &= \frac{S + I + D}{N} \\ &= \frac{1 + 0 + 4}{22} \\ &= 0.227 \end{aligned}

没错!我们的 CER 为 0.227,即 22.7%。请注意,这比我们的 WER 低——我们对拼写错误的惩罚要小得多。

我应该使用哪个指标?

一般来说,WER 在评估语音系统时比 CER 使用得更广泛。这是因为 WER 要求系统更好地理解预测的上下文。在我们的例子中,“sit”是错误的语态。一个理解动词和句子语态之间关系的系统会预测出正确的动词语态“sat”。我们希望鼓励我们的语音系统达到这种理解水平。因此,尽管 WER 比 CER 更不宽容,但它也更有利于我们开发出更智能的系统。因此,我们通常使用 WER,也鼓励您这样做!然而,在某些情况下无法使用 WER。某些语言,如普通话和日语,没有“词”的概念,因此 WER 毫无意义。在这种情况下,我们转而使用 CER。

在我们的例子中,我们在计算 WER 时只使用了一个句子。在评估真实系统时,我们通常会使用包含数千个句子的整个测试集。在评估多个句子时,我们会汇总所有句子的 S、I、D 和 N,然后根据上面定义的公式计算 WER。这能更好地估计未见数据的 WER。

标准化

如果我们在带有标点符号和大小写的数据上训练 ASR 模型,它将学会预测转录中的大小写和标点符号。当我们想将模型用于实际的语音识别应用程序时(例如会议转录或口述),这非常有用,因为预测的转录将完全格式化,带有大小写和标点符号,这种风格被称为*正字法*。

然而,我们也可以选择对数据集进行*规范化*,以去除任何大小写和标点符号。规范化数据集使语音识别任务变得更容易:模型不再需要区分大小写字符,也不必仅凭音频数据预测标点符号(例如,分号发出什么声音?)。因此,词错误率自然会更低(这意味着结果更好)。Whisper 论文展示了规范化转录对 WER 结果的巨大影响(*参见* Whisper 论文的 4.4 节)。虽然我们得到了更低的 WER,但模型对于生产来说不一定更好。缺少大小写和标点符号使得模型预测的文本阅读起来明显更困难。以上一节的示例为例,我们在 LibriSpeech 数据集的相同音频样本上运行了 Wav2Vec2 和 Whisper 模型。Wav2Vec2 模型不预测标点符号和大小写,而 Whisper 则两者都预测。并排比较转录,我们发现 Whisper 的转录更容易阅读。

Wav2Vec2:  HE TELLS US THAT AT THIS FESTIVE SEASON OF THE YEAR WITH CHRISTMAUS AND ROSE BEEF LOOMING BEFORE US SIMALYIS DRAWN FROM EATING AND ITS RESULTS OCCUR MOST READILY TO THE MIND
Whisper:   He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly is drawn from eating and its results occur most readily to the mind.

Whisper 的转录是正字法的,因此可以直接使用——它按照我们期望的会议转录或口述脚本格式化,包括标点符号和大小写。相反,如果我们要将 Wav2Vec2 用于下游应用程序,则需要使用额外的后处理来恢复 Wav2Vec2 预测中的标点符号和大小写。

在规范化和不规范化之间有一个折衷方案:我们可以在正字法转录本上训练我们的系统,然后在计算 WER 之前规范化预测和目标。这样,我们训练我们的系统来预测完全格式化的文本,同时也能从规范化转录本所带来的 WER 改进中受益。

Whisper 模型发布时带有一个规范化器,它能有效地处理大小写、标点符号和数字格式等的规范化。让我们将规范化器应用于 Whisper 转录本,以演示我们如何规范化它们。

from transformers.models.whisper.english_normalizer import BasicTextNormalizer

normalizer = BasicTextNormalizer()

prediction = " He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly is drawn from eating and its results occur most readily to the mind."
normalized_prediction = normalizer(prediction)

normalized_prediction

输出

' he tells us that at this festive season of the year with christmas and roast beef looming before us similarly is drawn from eating and its results occur most readily to the mind '

太好了!我们可以看到文本已经完全小写,并且所有标点符号都已删除。现在让我们定义参考转录,然后计算参考和预测之间的标准化 WER。

reference = "HE TELLS US THAT AT THIS FESTIVE SEASON OF THE YEAR WITH CHRISTMAS AND ROAST BEEF LOOMING BEFORE US SIMILES DRAWN FROM EATING AND ITS RESULTS OCCUR MOST READILY TO THE MIND"
normalized_referece = normalizer(reference)

wer = wer_metric.compute(
    references=[normalized_referece], predictions=[normalized_prediction]
)
wer

输出

0.0625

6.25%——这大致是 Whisper 基础模型在 LibriSpeech 验证集上预期的结果。正如我们在这里看到的,我们预测了一个正字法转录,但通过在计算 WER 之前对参考和预测进行标准化,获得了 WER 的提升。

如何规范化转录的最终选择取决于您的需求。我们建议在正字法文本上进行训练,并在规范化文本上进行评估,以获得两全其美的效果。

整合所有内容

好的!本单元我们已经介绍了三个主题:预训练模型、数据集选择和评估。让我们来点乐趣,将它们整合到一个端到端的示例中 🚀 我们将通过在 Common Voice 13 Dhivehi 测试集上评估预训练的 Whisper 模型来为下一节的微调做好准备。我们将把获得的 WER 数字作为我们微调运行的*基线*,或者我们试图超越的目标 🥊

首先,我们将使用 pipeline() 类加载预训练的 Whisper 模型。这个过程现在应该非常熟悉了!我们唯一要做的就是,如果在 GPU 上运行,则以半精度 (float16) 模式加载模型——这将加速推理,且几乎不会牺牲 WER 精度。

from transformers import pipeline
import torch

if torch.cuda.is_available():
    device = "cuda:0"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32

pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    torch_dtype=torch_dtype,
    device=device,
)

接下来,我们将加载 Common Voice 13 的迪维希语测试分割。您可能还记得,在上一节中,Common Voice 13 是*受限的*,这意味着我们必须同意数据集使用条款才能访问数据集。现在我们可以将 Hugging Face 账户链接到我们的笔记本,这样我们就可以从当前使用的机器访问数据集了。

将笔记本链接到 Hub 很简单——只需在提示时输入您的 Hub 身份验证令牌即可。在此处查找您的 Hub 身份验证令牌,并在提示时输入。

from huggingface_hub import notebook_login

notebook_login()

太好了!一旦我们将笔记本链接到 Hugging Face 账户,就可以继续下载 Common Voice 数据集了。这将需要几分钟才能下载和预处理,它会从 Hugging Face Hub 获取数据并自动在您的笔记本上进行准备。

from datasets import load_dataset

common_voice_test = load_dataset(
    "mozilla-foundation/common_voice_13_0", "dv", split="test"
)
如果您在加载数据集时遇到身份验证问题,请确保您已通过以下链接在 Hugging Face Hub 上接受了数据集的使用条款:https://huggingface.co/datasets/mozilla-foundation/common_voice_13_0

对整个数据集进行评估与对单个示例进行评估非常相似——我们所要做的就是*遍历*输入音频,而不是只推理单个样本。为此,我们首先将数据集转换为 `KeyDataset`。这只是挑选出我们想要转发到模型的特定数据集列(在我们的例子中,是 `"audio"` 列),忽略其余的(例如我们不想用于推理的目标转录)。然后我们遍历这些转换后的数据集,将模型输出附加到一个列表中以保存预测。以下代码单元如果在 GPU 上以半精度运行,将大约需要五分钟,峰值内存占用为 12GB。

from tqdm import tqdm
from transformers.pipelines.pt_utils import KeyDataset

all_predictions = []

# run streamed inference
for prediction in tqdm(
    pipe(
        KeyDataset(common_voice_test, "audio"),
        max_new_tokens=128,
        generate_kwargs={"task": "transcribe"},
        batch_size=32,
    ),
    total=len(common_voice_test),
):
    all_predictions.append(prediction["text"])
如果您在运行上述单元格时遇到 CUDA 内存不足 (OOM) 错误,请以 2 的倍数递减 `batch_size`,直到找到适合您设备的批处理大小。

最后,我们可以计算 WER。让我们首先计算正字法 WER,即不进行任何后处理的 WER。

from evaluate import load

wer_metric = load("wer")

wer_ortho = 100 * wer_metric.compute(
    references=common_voice_test["sentence"], predictions=all_predictions
)
wer_ortho

输出

167.29577268612022

好吧……167% 基本上意味着我们的模型输出的是垃圾 😜 别担心,我们的目标是通过在迪维希语训练集上微调模型来改进这一点!

接下来,我们将评估规范化 WER,即经过规范化后处理的 WER。我们必须过滤掉规范化后将为空的样本,否则我们的参考(N)中的总词数将为零,这将导致计算中出现除零错误。

from transformers.models.whisper.english_normalizer import BasicTextNormalizer

normalizer = BasicTextNormalizer()

# compute normalised WER
all_predictions_norm = [normalizer(pred) for pred in all_predictions]
all_references_norm = [normalizer(label) for label in common_voice_test["sentence"]]

# filtering step to only evaluate the samples that correspond to non-zero references
all_predictions_norm = [
    all_predictions_norm[i]
    for i in range(len(all_predictions_norm))
    if len(all_references_norm[i]) > 0
]
all_references_norm = [
    all_references_norm[i]
    for i in range(len(all_references_norm))
    if len(all_references_norm[i]) > 0
]

wer = 100 * wer_metric.compute(
    references=all_references_norm, predictions=all_predictions_norm
)

wer

输出

125.69809089960707

我们再次看到了通过规范化我们的参考文献和预测所实现的 WER 的大幅降低:基线模型的正字法测试 WER 为 168%,而规范化 WER 为 126%。

好的!这些就是我们微调模型时想要尝试超越的数字,以便改进 Whisper 模型在迪维希语语音识别方面的表现。继续阅读以动手实践微调示例 🚀

< > 在 GitHub 上更新