使用推测解码让 Whisper 推理速度翻倍

发布于 2023 年 12 月 20 日
在 GitHub 上更新
Open In Colab

OpenAI 的 Whisper 是一款通用语音转录模型,它在一系列不同的基准测试和音频条件下均取得了最先进的成果。最新的 large-v3 模型在 OpenASR 排行榜上名列前茅,被评为英语领域最佳的开源语音转录模型。该模型还表现出强大的多语言性能,在 Common Voice 15 数据集测试的 58 种语言中,有 42 种语言的词错误率 (WER) 低于 30%。

尽管转录准确率非常出色,但其推理速度却很慢。即使利用了 Flash Attention、半精度和分块 (chunking) 等推理优化技术,在 16GB T4 GPU 上转录 1 小时的音频片段也需要 6 分钟以上。

在这篇博文中,我们将演示如何利用推测解码 (Speculative Decoding) 将 Whisper 的推理时间减少 2 倍,同时从数学上确保模型获得完全相同的输出。因此,该方法可以完美地替代现有的 Whisper 工作流,因为它在保持相同准确率的同时,免费提供了 2 倍的速度提升。如果想看一个解释更少但包含所有代码的精简版博文,请参阅配套的 Google Colab

推测解码

推测解码由 Google 的 Yaniv Leviathan 等人在论文 《Fast Inference from Transformers via Speculative Decoding》 中提出。其工作原理基于这样一个前提:一个更快的**辅助模型**通常会生成与一个更大的**主模型**相同的词元 (token)。

首先,辅助模型自回归地生成一个包含 N N 个**候选词元**的序列,y^1:N \hat{\boldsymbol{y}}_{1:N} 。在下图中,辅助模型生成了 5 个候选词元组成的序列:The quick brown sock jumps

虽然这些候选词元生成得很快,但它们可能与主模型预测的词元不同。因此,在第二步中,这些候选词元被传递给主模型进行“验证”。主模型将候选词元作为输入,并执行**单次前向传播**。主模型的输出是词元序列 y1:N \boldsymbol{y}_{1:N} 中每一步的“正确”词元。

在上图中,我们看到主模型预测的前三个词元与辅助模型的预测一致:The quick brown。然而,辅助模型的第四个候选词元 sock 与主模型的正确词元 fox 不匹配。

我们知道,在第一个不匹配出现之前的所有候选词元(The quick brown)都是正确的,因为它们与主模型的预测一致。然而,在第一个不匹配之后,候选词元就与主模型预测的实际词元产生了分歧。因此,我们可以用主模型的正确词元(fox)替换掉第一个不正确的候选词元(sock),并丢弃其后所有预测的词元,因为它们已经偏离了。修正后的序列 The quick brown fox 现在成为辅助模型的新输入。

然后推理过程重复,辅助模型生成一组新的 N N 个候选词元,再由主模型通过一次前向传播进行验证。

由于我们使用快速的辅助模型进行自回归生成,而只用慢速的主模型进行验证性的前向传播,解码过程得以大幅加速。此外,主模型执行的验证性前向传播确保了我们能获得与单独使用主模型时**完全相同的输出**。这使得推测解码成为现有 Whisper 工作流的完美替代品,因为可以确信能达到同样的质量。

为了最大程度地减少延迟,辅助模型应该比主模型快得多,同时尽可能频繁地预测出相同的词元分布。实际上,这两个属性之间存在一种权衡:模型越快,其准确性就越低。然而,由于 70-80% 的预测词元往往是“较简单”的词元,这种权衡倾向于选择更快的模型,而不是更准确的模型。因此,辅助模型应该比主模型快至少 3 倍(越快越好),同时能正确预测示例中所有“简单”的词元。剩下的 20-30% 更“困难”的词元则可以由更大的主模型来验证。

选择辅助模型的唯一限制是它必须与主模型共享相同的词汇表。也就是说,辅助模型必须使用与主模型完全相同的分词器 (tokenizer)。因此,如果我们想对 Whisper 的多语言版本(例如 large-v2 (多语言))使用推测解码,我们需要选择一个 Whisper 的多语言版本作为辅助模型,例如 tiny。而如果我们想对 Whisper 的纯英文版本(例如 medium.en)使用推测解码,则需要一个纯英文版本的辅助模型,例如 tiny.en。目前,Whisper large-v3 是一个例外,因为它是唯一一个词汇表大小经过扩展的 Whisper 模型检查点,因此与之前的 Whisper 模型检查点不兼容。

现在我们了解了推测解码的背景知识,可以开始进行实际操作了。在 🤗 Transformers 库中,推测解码被实现为“辅助生成”(assisted generation) 推理策略。有关该实现的更多详细信息,建议读者阅读 Joao Gante 撰写的关于辅助生成的精彩博文。

英语语音转录

基准实现

我们首先对 Whisper large-v2 进行基准测试,以获取推理速度的基准数据。我们可以通过便捷的 AutoModelForSpeechSeq2SeqAutoProcessor 类来加载主模型及其对应的处理器。我们将以 float16 精度加载模型,并通过传递 low_cpu_mem_usage=True 来确保加载时间尽可能短。此外,我们希望确保模型以 safetensors 格式加载,因此传递 use_safetensors=True。最后,我们将传递参数 attn_implementation="sdpa",以通过 PyTorch 的 SDPA 注意力核来利用 Flash Attention 的加速效果。

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-large-v2"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

让我们加载用于基准测试的英语语音转录数据集。我们将从 LibriSpeech ASR validation-clean 数据集中加载一个包含 73 个样本的小型数据集。这大约相当于 9MB 的数据,因此非常轻量级,可以快速下载到设备上。

from datasets import load_dataset

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

对于基准测试,我们只想测量生成时间,所以让我们编写一个简短的辅助函数来测量这一步。以下函数将返回解码后的词元和运行模型所花费的时间。

import time

def generate_with_time(model, inputs, **kwargs):
    start_time = time.time()
    outputs = model.generate(**inputs, **kwargs)
    generation_time = time.time() - start_time
    return outputs, generation_time

我们现在可以遍历数据集中的音频样本,并累加总生成时间。

from tqdm import tqdm

all_time = 0
predictions = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    inputs = inputs.to(device=device, dtype=torch.float16)
    
    output, gen_time = generate_with_time(model, inputs)
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["text"]))

print(all_time)

输出

100%|██████████| 73/73 [01:37<00:00,  1.33s/it]
72.99542546272278

好的!我们看到转录 73 个样本花费了 73 秒。现在来检查一下预测结果的词错误率 (WER)。

from evaluate import load

wer = load("wer")
print(wer.compute(predictions=predictions, references=references))

输出

0.03507271171941831

我们的最终基准数据是 73 秒的运行时间和 3.5% 的词错误率。

推测解码

现在让我们加载用于推测解码的辅助模型。在这个例子中,我们将使用 Whisper 的一个蒸馏变体,distil-large-v2。这个蒸馏模型复制了 Whisper 的整个编码器,但解码器层数从 32 层减少到了 2 层。因此,它的运行速度比 Whisper 快 6 倍,而在非分布测试集上的词错误率 (WER) 仅相差 1% 以内。这使其成为辅助模型的完美选择,因为它兼具高转录准确率和快速生成速度的优点1{}^1

由于 Distil-Whisper 使用与 Whisper 模型完全相同的编码器,我们可以在主模型和辅助模型之间共享编码器。这样,我们只需要将 Distil-Whisper 的 2 层解码器作为“仅解码器”模型加载。我们可以通过便捷的 AutoModelForCausalLM auto 类来完成此操作。实际上,这只会比单独使用主模型增加 8% 的显存占用。

from transformers import AutoModelForCausalLM

assistant_model_id = "distil-whisper/distil-large-v2"

assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)

assistant_model.to(device)

1{}^1 我们计划发布一个改进版的 Distil-Whisper,它在词元分布上具有更强的一致性,这将进一步提高推测解码的性能。请关注 Distil-Whisper 代码库以获取更新。


我们可以为我们的推测解码基准测试定义一个修改过的函数。与之前的函数唯一的区别是,我们在调用 .generate 时传递了辅助模型。

def assisted_generate_with_time(model, inputs, **kwargs):
    start_time = time.time()
    outputs = model.generate(**inputs, assistant_model=assistant_model, **kwargs)
    generation_time = time.time() - start_time
    return outputs, generation_time

让我们使用 Distil-Whisper 作为 Whisper 的辅助模型来运行推测解码的基准测试。

all_time = 0
predictions = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    inputs = inputs.to(device=device, dtype=torch.float16)
    
    output, gen_time = assisted_generate_with_time(model, inputs)
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["text"]))

print(all_time)

输出

100%|██████████| 73/73 [00:38<00:00,  1.88it/s]
32.69683289527893

使用推测解码后,推理时间仅为 33 秒,比之前快了 2.2 倍!让我们验证一下词错误率 (WER) 是否相同。

print(wer.compute(predictions=predictions, references=references))

输出

0.03507271171941831

完美!词错误率 (WER) 仍然是 3.5%,因为我们得到了与单独使用主模型时完全相同的输出。

推测解码也可以与简单易用的 🤗 Transformers pipeline API 一起使用进行推理。下面,我们使用模型和处理器实例化 pipeline,然后用它来转录玩具数据集中的第一个样本。这种方法可以扩展到转录任意长度的音频样本,包括使用批处理。

from transformers import pipeline

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    batch_size=4,
    generate_kwargs={"assistant_model": assistant_model},
    torch_dtype=torch_dtype,
    device=device,
)

sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])

输出

 Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.

Distil-Whisper 模型卡上可以找到一个端到端的代码片段,用于运行 Whisper 和 Distil-Whisper 的推测解码。它将本笔记本中涵盖的推理阶段整合到一个代码示例中。

多语言语音转录

Distil-Whisper 是英语语音转录的理想辅助模型,因为它在短音频和长音频样本上的性能与原始 Whisper 模型的词错误率 (WER) 相差不到 1%,而速度却快了 6 倍。然而,官方的 Distil-Whisper 模型检查点仅支持英语,这意味着它们不能用于多语言语音转录。

要将推测解码用于多语言语音转录,可以使用官方的多语言 Whisper 模型检查点之一,或者一个经过微调的 Whisper 变体。在撰写本文时,Hugging Face Hub 上有超过 5000 个经过微调的 Whisper 模型检查点,涵盖 100 多种语言。这些为选择在单一语言上表现出色的辅助 Whisper 模型检查点提供了绝佳的起点。在本例中,我们将使用最小的官方多语言模型检查点,即 Whisper tiny。欢迎你尝试在你所用语言上微调过的不同模型检查点!

让我们加载新辅助模型 Whisper tiny 的权重。由于 Whisper tiny 中的编码器与 large-v2 中的不同,这次我们将使用 AutoModelForSpeechSeq2Seq 类同时加载编码器和解码器。

assistant_model_id = "openai/whisper-tiny"

assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    assistant_model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)

assistant_model.to(device);

对于我们的基准测试数据集,我们将从 VoxPopuli 数据集的荷兰语 ("nl") 部分加载 73 个样本。

dataset = load_dataset("sanchit-gandhi/voxpopuli_dummy", "nl", split="validation")

太棒了!我们现在可以像之前一样,为我们的基准 Whisper large-v2 模型重新运行基准测试。唯一的变化是,我们向 generate 函数传递了语言和任务参数,以确保我们执行的是语音转录(而不是语音翻译)。推测解码与语音转录和语音翻译任务完全兼容。只需根据需要设置任务参数即可,如下所示。

all_time = 0
predictions = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    inputs = inputs.to(device=device, dtype=torch.float16)
    
    output, gen_time = generate_with_time(model, inputs, language="nl", task="transcribe")
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["normalized_text"]))

wer_result = wer.compute(predictions=predictions, references=references)

print("Time:", all_time)
print("WER:", wer_result)

输出

100%|██████████| 73/73 [02:05<00:00,  1.72s/it]
Time: 116.50992178916931
WER: 0.127190136275146

好的!我们的基准时间是 117 秒,词错误率 (WER) 为 12.8%。现在让我们使用推测解码重新运行生成过程。

all_time = 0
predictions = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    inputs = inputs.to(device=device, dtype=torch.float16)

    output, gen_time = assisted_generate_with_time(model, inputs, language="nl", task="transcribe")
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["normalized_text"]))

wer_result = wer.compute(predictions=predictions, references=references)

print("Time:", all_time)
print("WER:", wer_result)

输出

100%|██████████| 73/73 [01:08<00:00,  1.06it/s]
Time: 62.10229682922363
WER: 0.127190136275146

我们再次达到了 12.8% 的词错误率 (WER),但这次推理时间仅为 62 秒,速度提升了 1.9 倍。考虑到加载辅助模型的开销很低,并且在数学上能保证获得完全相同的输出,推测解码为现有的 Whisper 工作流提供了一个完美的替代方案。

高效推测解码策略

在最后一部分,我们将介绍两种策略,以确保使用推测解码时能获得最快的推理速度。

辅助模型

我们的目标是选择一个比主模型快至少 3 倍**并且**能正确转录至少 70-80% 预测词元的辅助模型,这些词元通常是示例中“较简单”的词元。如果你有特定语言的转录需求,一个有效的策略是训练两个不同大小的 Whisper 模型,并用一个作为另一个的辅助模型。

  • 首先,微调 Whisper large-v3 作为你的主模型。
  • 其次,在相同的数据集上蒸馏 Whisper large-v3,以作为快速的辅助模型。

微调和蒸馏可以提高主模型和辅助模型在你所选语言上的词错误率 (WER) 性能,同时最大化词元分布的一致性。关于 Whisper 微调的完整指南可以在这里找到,蒸馏的指南在这里

批处理大小

值得注意的是,推测解码在批处理大小为 1 时能获得最大的速度提升。对于批处理的推测解码,**批次中所有**候选词元都必须与验证词元匹配,这些词元才会被接受。如果批次中某个位置的词元不一致,那么该位置之后的所有候选词元都将被丢弃。因此,推测解码更适合较小的批处理大小。在实践中,我们发现推测解码在批处理大小达到 4 之前都能提供加速效果。当批处理大小超过 4 时,推测解码的推理速度会比单独使用主模型还要慢。完整结果请参考 Distil-Whisper 论文的 D.3 节。

结论

在这篇博文中,我们介绍了应用于 Whisper 模型进行语音转录的推测解码推理策略。我们展示了如何实现 2 倍的速度提升,同时在数学上确保输出与单独使用原始模型完全相同。我们鼓励你尝试使用推测解码作为现有 Whisper 工作流的直接替代方案,因为它使用额外辅助模型的开销很低,并且能保证获得相同的转录结果。

致谢

博文作者 Sanchit Gandhi。非常感谢 Patrick von PlatenPedro Cuenca 提出的建设性意见,以及 Joao Gante 在 🤗 Transformers 中实现的辅助生成功能。

社区

注册登录以发表评论