音频课程文档

ASR 评估指标

Hugging Face's logo
加入 Hugging Face 社区

并获取增强型文档体验

入门

ASR 评估指标

如果你熟悉 NLP 中的 莱文斯坦距离,那么评估语音识别系统的指标对你来说就很熟悉了!如果你不熟悉,也不用担心,我们会从头到尾解释,确保你了解不同的指标并理解它们的含义。

在评估语音识别系统时,我们将系统预测与目标文本转录进行比较,并标注任何出现的错误。我们将这些错误归类为以下三种类别之一:

  1. 替换 (S):在预测中转录了错误的词语(“sit”而不是“sat”)
  2. 插入 (I):在预测中添加了多余的词语
  3. 删除 (D):在预测中删除了词语

所有语音识别指标的错误类别都是相同的。不同之处在于我们计算这些错误的级别:我们可以按词语级别或按字符级别计算它们。

我们将使用每个指标定义的运行示例。在这里,我们有一个真实值参考文本序列

reference = "the cat sat on the mat"

以及来自我们试图评估的语音识别系统的预测序列

prediction = "the cat sit on the"

我们可以看到,预测非常接近,但有些词语并不完全正确。我们将评估该预测相对于参考值,针对三种最流行的语音识别指标,看看我们为每个指标获得了什么数字。

词错误率

词错误率 (WER) 指标是语音识别的“事实上的”指标。它在词语级别计算替换、插入和删除。这意味着错误是在逐字基础上标注的。以我们的示例为例

参考值 the cat sat on the mat
预测 the cat sit on the
标签 S D

在这里,我们有

  • 1 次替换(“sit”而不是“sat”)
  • 0 次插入
  • 1 次删除(“mat”缺失)

总共 2 个错误。为了获得我们的错误率,我们将错误数量除以参考值中的总词语数量 (N),在本示例中为 6WER=S+I+DN=1+0+16=0.333 \begin{aligned} WER &= \frac{S + I + D}{N} \\ &= \frac{1 + 0 + 1}{6} \\ &= 0.333 \end{aligned}

好的!因此,我们的 WER 为 0.333 或 33.3%。请注意,“sit”这个词只有一个字符是错误的,但整个词都被标记为错误。这是 WER 的一个定义特征:拼写错误会受到严重惩罚,无论其有多小。

WER 的定义是越低越好:较低的 WER 表示预测中的错误更少,因此,完美的语音识别系统将具有零 WER(无错误)。

让我们看看如何使用 🤗 Evaluate 计算 WER。我们需要两个包来计算我们的 WER 指标:🤗 Evaluate 用于 API 接口,JIWER 用于进行计算的繁重工作

pip install --upgrade evaluate jiwer

太好了!我们现在可以加载 WER 指标并计算我们示例的数字

from evaluate import load

wer_metric = load("wer")

wer = wer_metric.compute(references=[reference], predictions=[prediction])

print(wer)

打印输出

0.3333333333333333

0.33 或 33.3%,正如预期的那样!我们现在知道 WER 计算背后的工作原理。

现在,这里有一些让人困惑的东西……你认为 WER 的上限是多少?你可能会期望它为 1 或 100%?不!由于 WER 是错误数量与词语数量 (N) 的比率,因此 WER 没有上限!让我们举一个例子,我们预测了 10 个词,而目标值只有 2 个词。如果我们的所有预测都是错误的(10 个错误),那么我们的 WER 将为 10 / 2 = 5 或 500%!这一点需要牢记在心,如果你训练了 ASR 系统并看到 WER 超过 100%。不过,如果你看到这种情况,很可能出了一些问题……😅

词准确率

我们可以将 WER 翻转过来,得到一个越高越好的指标。与其衡量词错误率,不如衡量系统的词准确率 (WAcc)WAcc=1WER \begin{equation} WAcc = 1 - WER \nonumber \end{equation}

WAcc 也是在词级别上衡量的,只是 WER 被重新表述为一个准确率指标,而不是一个错误指标。WAcc 在语音文献中很少被引用 - 我们从词错误的角度思考系统的预测结果,因此更喜欢与这些错误类型标注相关的错误率指标。

字符错误率

我们把整个单词 “sit” 标注为错误似乎有点不公平,因为实际上只有一个字母是错误的。这是因为我们是在词级别上评估系统,因此逐词标注错误。字符错误率 (CER)字符级别上评估系统。这意味着我们将单词分解成单个字符,并在字符级别上标注错误。

参考值 t h e c a t s a t o n t h e m a t
预测 t h e c a t s i t o n t h e
标签 S D D D

现在我们可以看到,对于单词 “sit”,“s” 和 “t” 被标记为正确。只有 “i” 被标记为替换错误 (S)。因此,我们奖励系统对部分正确预测的识别🤝

在我们的例子中,我们有 1 个字符替换,0 个插入和 3 个删除。总共有 14 个字符。所以,我们的 CER 是CER=S+I+DN=1+0+314=0.286 \begin{aligned} CER &= \frac{S + I + D}{N} \\ &= \frac{1 + 0 + 3}{14} \\ &= 0.286 \end{aligned}

没错!我们的 CER 是 0.286,也就是 28.6%。请注意,这低于我们的 WER - 我们对拼写错误的惩罚要轻得多。

我应该使用哪个指标?

一般来说,WER 比 CER 更常用于评估语音系统。这是因为 WER 要求系统对预测结果的上下文有更深的理解。在我们的例子中,“sit” 的时态是错误的。一个理解动词和句子时态之间关系的系统应该预测出正确的动词时态 “sat”。我们希望鼓励我们的语音系统达到这种理解水平。因此,尽管 WER 比 CER 更加严苛,但它也有利于我们希望开发的更易理解的系统。因此,我们通常使用 WER,也建议你使用 WER!但是,在某些情况下,无法使用 WER。某些语言,如普通话和日语,没有“词”的概念,因此 WER 毫无意义。在这种情况下,我们将使用 CER。

在我们的例子中,我们在计算 WER 时只使用了一个句子。在评估真实系统时,我们通常会使用包含数千个句子的完整测试集。在评估多个句子时,我们将所有句子的 S、I、D 和 N 汇总,然后根据上面定义的公式计算 WER。这将更好地估计看不见数据的 WER。

规范化

如果我们用包含标点符号和大小写的语料库训练一个 ASR 模型,它将学习在转录中预测大小写和标点符号。当我们希望将模型用于实际的语音识别应用(如转录会议或听写)时,这很好,因为预测的转录结果将完全格式化,包括大小写和标点符号,这种风格被称为正字法

然而,我们也可以选择规范化数据集,以去除任何大小写和标点符号。规范化数据集可以使语音识别任务变得更容易:模型不再需要区分大小写字符,也不需要单独从音频数据中预测标点符号(例如,分号是什么声音?)。因此,词错误率自然会降低(这意味着结果更好)。Whisper 论文展示了规范化转录对 WER 结果的显著影响(参阅 Whisper 论文的第 4.4 节)。虽然我们可以获得更低的 WER,但模型并不一定更适合生产环境。缺乏大小写和标点符号会使模型预测的文本难以阅读。以上一节中的例子为例,我们在 LibriSpeech 数据集的同一个音频样本上运行了 Wav2Vec2 和 Whisper。Wav2Vec2 模型既不预测标点符号也不预测大小写,而 Whisper 则预测两者。将转录文本并排比较,我们可以看到 Whisper 的转录文本更容易阅读。

Wav2Vec2:  HE TELLS US THAT AT THIS FESTIVE SEASON OF THE YEAR WITH CHRISTMAUS AND ROSE BEEF LOOMING BEFORE US SIMALYIS DRAWN FROM EATING AND ITS RESULTS OCCUR MOST READILY TO THE MIND
Whisper:   He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly is drawn from eating and its results occur most readily to the mind.

Whisper 的转录文本是正字法的,因此可以直接使用 - 它以我们期望的格式进行格式化,如会议转录或听写脚本,包括标点符号和大小写。相反,如果我们想将 Wav2Vec2 用于下游应用,则需要使用额外的后处理来恢复标点符号和大小写。

规范化和不规范化之间存在一个折衷方案:我们可以用正字法转录训练我们的系统,然后在计算 WER 之前规范化预测和目标。这样,我们就可以训练我们的系统来预测格式完整的文本,同时也能从规范化转录带来的 WER 改进中获益。

Whisper 模型发布了一个规范化器,可以有效地处理大小写、标点符号和数字格式等的规范化。让我们将规范化器应用于 Whisper 转录文本,以演示如何规范化它们。

from transformers.models.whisper.english_normalizer import BasicTextNormalizer

normalizer = BasicTextNormalizer()

prediction = " He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly is drawn from eating and its results occur most readily to the mind."
normalized_prediction = normalizer(prediction)

normalized_prediction

输出

' he tells us that at this festive season of the year with christmas and roast beef looming before us similarly is drawn from eating and its results occur most readily to the mind '

太好了!我们可以看到文本已全部转换为小写,并且所有标点符号都被删除。现在让我们定义参考转录文本,然后计算参考文本和预测文本之间的规范化 WER。

reference = "HE TELLS US THAT AT THIS FESTIVE SEASON OF THE YEAR WITH CHRISTMAS AND ROAST BEEF LOOMING BEFORE US SIMILES DRAWN FROM EATING AND ITS RESULTS OCCUR MOST READILY TO THE MIND"
normalized_referece = normalizer(reference)

wer = wer_metric.compute(
    references=[normalized_referece], predictions=[normalized_prediction]
)
wer

输出

0.0625

6.25% - 这大约是我们在 LibriSpeech 验证集上对 Whisper 基础模型的预期结果。正如我们在这里看到的,我们预测了一个正字法转录文本,但从在计算 WER 之前规范化参考文本和预测文本所获得的 WER 提升中受益。

最终,如何规范化转录文本取决于您的需求。我们建议用正字法文本训练,用规范化文本评估,以获得两全其美。

综合概述

好了!我们已经在这单元中涵盖了三个主题:预训练模型、数据集选择和评估。让我们来点有趣的,将它们整合到一个端到端的例子中 🚀 我们将通过评估预训练的 Whisper 模型在 Common Voice 13 Dhivehi 测试集上的表现来为下一节微调做准备。我们将使用得到的 WER 数值作为微调运行的基线,或者作为我们努力超越的目标 🥊

首先,我们将使用 pipeline() 类加载预训练的 Whisper 模型。这个过程现在应该非常熟悉了!我们将做的唯一不同的事情是,如果在 GPU 上运行,则以半精度(float16)加载模型 - 这将以几乎不影响 WER 准确性的代价加速推理速度。

from transformers import pipeline
import torch

if torch.cuda.is_available():
    device = "cuda:0"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32

pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    torch_dtype=torch_dtype,
    device=device,
)

接下来,我们将加载 Common Voice 13 的 Dhivehi 测试集。您会记得上一节中提到 Common Voice 13 是受限的,这意味着我们在访问数据集之前必须同意数据集的使用条款。现在,我们可以将我们的 Hugging Face 帐户链接到我们的笔记本,以便我们能够从当前使用的机器访问数据集。

将笔记本链接到 Hub 很简单 - 只需要在提示时输入您的 Hub 认证令牌即可。您可以在此处找到您的 Hub 认证令牌,并在提示时输入。

from huggingface_hub import notebook_login

notebook_login()

很好!一旦将笔记本链接到我们的 Hugging Face 帐户,我们就可以继续下载 Common Voice 数据集。这将需要几分钟来下载和预处理,从 Hugging Face Hub 获取数据并在您的笔记本上自动准备它。

from datasets import load_dataset

common_voice_test = load_dataset(
    "mozilla-foundation/common_voice_13_0", "dv", split="test"
)
如果您在加载数据集时遇到认证问题,请确保您已通过以下链接在 Hugging Face Hub 上接受了数据集的使用条款:https://huggingface.co/datasets/mozilla-foundation/common_voice_13_0。

对整个数据集进行评估与对单个示例进行评估的方式基本相同 - 我们只需要循环遍历输入音频,而不是仅推断单个样本。为此,我们首先将数据集转换为 KeyDataset。它所做的只是选出我们要转发给模型的特定数据集列(在本例中为 "audio" 列),而忽略其他列(例如目标转录文本,我们不想用于推理)。然后,我们遍历这个转换后的数据集,将模型输出追加到列表中以保存预测结果。如果在具有半精度的 GPU 上运行,以下代码单元将大约需要五分钟,内存峰值约为 12GB。

from tqdm import tqdm
from transformers.pipelines.pt_utils import KeyDataset

all_predictions = []

# run streamed inference
for prediction in tqdm(
    pipe(
        KeyDataset(common_voice_test, "audio"),
        max_new_tokens=128,
        generate_kwargs={"task": "transcribe"},
        batch_size=32,
    ),
    total=len(common_voice_test),
):
    all_predictions.append(prediction["text"])
如果您在运行上面的代码单元时遇到 CUDA 内存不足 (OOM) 错误,请将 batch_size 按 2 的倍数递减,直到找到适合您设备的批次大小。

最后,我们可以计算 WER。首先,我们计算正字法 WER,即不进行任何后处理的 WER。

from evaluate import load

wer_metric = load("wer")

wer_ortho = 100 * wer_metric.compute(
    references=common_voice_test["sentence"], predictions=all_predictions
)
wer_ortho

输出

167.29577268612022

好吧……167% 基本上意味着我们的模型输出的是垃圾 😜 别担心,我们的目标是通过在 Dhivehi 训练集上微调模型来改进这一点!

接下来,我们将评估规范化 WER,即进行规范化后处理的 WER。我们必须过滤掉在规范化后会为空的样本,否则我们参考文本中的总词数 (N) 将为零,这将导致我们的计算中出现除零错误。

from transformers.models.whisper.english_normalizer import BasicTextNormalizer

normalizer = BasicTextNormalizer()

# compute normalised WER
all_predictions_norm = [normalizer(pred) for pred in all_predictions]
all_references_norm = [normalizer(label) for label in common_voice_test["sentence"]]

# filtering step to only evaluate the samples that correspond to non-zero references
all_predictions_norm = [
    all_predictions_norm[i]
    for i in range(len(all_predictions_norm))
    if len(all_references_norm[i]) > 0
]
all_references_norm = [
    all_references_norm[i]
    for i in range(len(all_references_norm))
    if len(all_references_norm[i]) > 0
]

wer = 100 * wer_metric.compute(
    references=all_references_norm, predictions=all_predictions_norm
)

wer

输出

125.69809089960707

我们再次看到了通过规范化参考文本和预测文本所获得的 WER 显著降低:基线模型的正字法测试 WER 为 168%,而规范化 WER 为 126%。

好了!这些是我们希望在微调模型时超越的数字,以便改进 Whisper 模型的 Dhivehi 语音识别性能。继续阅读,动手操作微调示例 🚀