使用 🤗 Transformers 微调 Whisper 模型以实现多语言 ASR
在本博客中,我们使用 Hugging Face 🤗 Transformers 为任意多语言 ASR 数据集微调 Whisper 提供了分步指南。本博客深入解释了 Whisper 模型、Common Voice 数据集和微调背后的理论,并附有代码单元格来执行数据准备和微调步骤。如果需要一个解释较少但包含所有代码的精简版笔记,请参阅随附的 Google Colab。
目录
引言
Whisper 是 OpenAI 的 Alec Radford 等作者于 2022 年 9 月发布的用于自动语音识别 (ASR) 的预训练模型。与许多前辈模型(如 Wav2Vec 2.0)不同,后者是在未标记的音频数据上进行预训练,而 Whisper 是在大量已标记的音频转录数据上进行预训练的,准确地说是 68 万小时。这比用于训练 Wav2Vec 2.0 的未标记音频数据(6 万小时)要多一个数量级。更重要的是,这 68 万小时的预训练数据中有 11.7 万小时是多语言 ASR 数据。这使得模型检查点可以应用于超过 96 种语言,其中许多被认为是低资源语言。
如此大量的标注数据使 Whisper 能够直接在语音识别的监督任务上进行预训练,从标注的音频-转录预训练数据中学习语音到文本的映射关系。因此,Whisper 只需要很少的额外微调就能产出一个高性能的 ASR 模型。这与 Wav2Vec 2.0 形成对比,后者是在无监督的掩码预测任务上预训练的。在这种情况下,模型被训练来仅从无标签的音频数据中学习从语音到隐藏状态的中间映射。虽然无监督预训练能产生高质量的语音表示,但它不会学习语音到文本的映射。这个映射只在微调期间学习,因此需要更多的微调才能获得有竞争力的性能。
当扩展到 68 万小时的标记预训练数据时,Whisper 模型展现出强大的泛化能力,能够适应多种数据集和领域。这些预训练检查点在性能上能与最先进的 ASR 系统相媲美,在 LibriSpeech ASR 的 test-clean 子集上达到了接近 3% 的词错误率(WER),并在 TED-LIUM 上以 4.7% 的 WER 创造了新的 SOTA(c.f. Whisper 论文的表 8)。Whisper 在预训练期间获得的广泛多语言 ASR 知识可以被用于其他低资源语言;通过微调,预训练的检查点可以适应特定的数据集和语言,以进一步提升这些结果。
Whisper 是一个基于 Transformer 的编码器-解码器模型,也称为序列到序列模型。它将音频频谱特征的序列映射到文本词元的序列。首先,原始音频输入通过特征提取器的作用被转换为对数梅尔频谱图。然后,Transformer 编码器对频谱图进行编码,形成一个编码器隐藏状态的序列。最后,解码器自回归地预测文本词元,其条件是之前的词元和编码器的隐藏状态。图 1 概述了 Whisper 模型。
在序列到序列模型中,编码器将音频输入转换为一组隐藏状态表示,从语音中提取重要特征。解码器扮演语言模型的角色,处理隐藏状态表示并生成相应的文本转录。在系统架构中内部集成语言模型被称为深度融合。这与浅层融合形成对比,后者是将语言模型与编码器外部结合,例如 CTC + -gram(c.f. Internal Language Model Estimation)。通过深度融合,整个系统可以使用相同的训练数据和损失函数进行端到端训练,从而提供更大的灵活性和通常更优越的性能(c.f. ESB Benchmark)。
Whisper 使用交叉熵目标函数进行预训练和微调,这是训练序列到序列系统进行分类任务的标准目标函数。在这里,系统被训练以从预定义的文本词元词汇表中正确分类目标文本词元。
Whisper 检查点有五种不同模型大小的配置。最小的四种是在纯英语或多语言数据上训练的。最大的检查点仅为多语言。所有 11 个预训练检查点都可以在 Hugging Face Hub 上找到。下表总结了这些检查点,并附有 Hub 上模型的链接。
大小 | 层数 | 宽度 | 注意力头数 | 参数量 | 仅英语 | 多语言 |
---|---|---|---|---|---|---|
tiny | 4 | 384 | 6 | 39 M | ✓ | ✓ |
base | 6 | 512 | 8 | 74 M | ✓ | ✓ |
small | 12 | 768 | 12 | 244 M | ✓ | ✓ |
medium | 24 | 1024 | 16 | 769 M | ✓ | ✓ |
large | 32 | 1280 | 20 | 1550 M | x | ✓ |
large-v2 | 32 | 1280 | 20 | 1550 M | x | ✓ |
large-v3 | 32 | 1280 | 20 | 1550 M | x | ✓ |
出于演示目的,我们将微调多语言版本的 small
检查点,其参数量为 244M (约 1GB)。至于我们的数据,我们将在 Common Voice 数据集中的一种低资源语言上训练和评估我们的系统。我们将展示,仅用 8 小时的微调数据,我们就能在该语言上取得强大的性能。
Whisper 这个名字来源于缩写“WSPSR”,代表“Web-scale Supervised Pre-training for Speech Recognition”(网络规模的监督式语音识别预训练)。
在 Google Colab 中微调 Whisper
准备环境
我们将使用几个流行的 Python 包来微调 Whisper 模型。我们将使用 `datasets[audio]` 来下载和准备我们的训练数据,同时使用 `transformers` 和 `accelerate` 来加载和训练我们的 Whisper 模型。我们还需要 `soundfile` 包来预处理音频文件,`evaluate` 和 `jiwer` 来评估我们模型的性能,以及 `tensorboard` 来记录我们的指标。最后,我们将使用 `gradio` 为我们微调的模型构建一个酷炫的演示。
!pip install --upgrade pip
!pip install --upgrade datasets[audio] transformers accelerate evaluate jiwer tensorboard gradio
我们强烈建议您在训练期间将模型检查点直接上传到 Hugging Face Hub。Hub 提供:
- 集成的版本控制:你可以确保在训练过程中不会丢失任何模型检查点。
- Tensorboard 日志:在训练过程中跟踪重要指标。
- 模型卡片:记录模型的功能及其预期用途。
- 社区:与社区分享和协作的便捷方式!
将笔记本链接到 Hub 非常简单——只需在提示时输入您的 Hub 身份验证令牌即可。在此处找到您的 Hub 身份验证令牌 here
from huggingface_hub import notebook_login
notebook_login()
打印输出
Login successful
Your token has been saved to /root/.huggingface/token
加载数据集
Common Voice 是一系列众包数据集,其中说话者用各种语言录制维基百科的文本。我们将使用撰写本文时 Common Voice 数据集的最新版本(版本 11)。至于我们的语言,我们将在 印地语上微调我们的模型,这是一种在印度北部、中部、东部和西部使用的印度-雅利安语。Common Voice 11.0 包含大约 12 小时的已标记印地语数据,其中 4 小时是保留的测试数据。
提示:您可以通过查看 Hugging Face Hub 上的 Mozilla Foundation 组织页面来找到 Common Voice 数据集的最新版本。较新版本涵盖更多语言,并且每种语言包含更多数据。
让我们前往 Hub 并查看 Common Voice 的数据集页面:mozilla-foundation/common_voice_11_0。
我们第一次查看此页面时,会被要求接受使用条款。之后,我们将获得对数据集的完全访问权限。
一旦我们提供了使用数据集的身份验证,我们将看到数据集预览。数据集预览向我们展示了数据集的前 100 个样本。更重要的是,它加载了音频样本,供我们实时收听。我们可以通过下拉菜单将子集设置为 `hi` 来选择 Common Voice 的印地语子集(`hi` 是印地语的语言标识符代码)。

如果我们点击第一个样本的播放按钮,我们可以收听音频并查看相应的文本。请滚动浏览训练集和测试集的样本,以便更好地了解我们正在处理的音频和文本数据。从语调和风格可以看出,这些录音来自旁白语音。您可能还会注意到说话者和录音质量的巨大差异,这是众包数据的常见特征。
使用 🤗 Datasets,下载和准备数据非常简单。我们只需一行代码就可以下载和准备 Common Voice 的各个数据分割。由于印地语资源非常有限,我们将合并 `train` 和 `validation` 分割,以提供大约 8 小时的训练数据。我们将使用 4 小时的 `test` 数据作为我们的保留测试集。
from datasets import load_dataset, DatasetDict
common_voice = DatasetDict()
common_voice["train"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="train+validation", use_auth_token=True)
common_voice["test"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test", use_auth_token=True)
print(common_voice)
打印输出
DatasetDict({
train: Dataset({
features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
num_rows: 6540
})
test: Dataset({
features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
num_rows: 2894
})
})
大多数 ASR 数据集只提供输入音频样本(`audio`)和相应的转录文本(`sentence`)。Common Voice 包含额外的元数据信息,如 `accent` 和 `locale`,我们可以在 ASR 中忽略这些信息。为了使笔记本尽可能通用,我们只考虑输入音频和转录文本进行微调,丢弃额外的元数据信息。
common_voice = common_voice.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"])
Common Voice 只是我们可以从 Hub 下载的众多多语言 ASR 数据集之一——还有更多可供我们选择!要查看可用于语音识别的数据集范围,请点击链接:Hub 上的 ASR 数据集。
准备特征提取器、分词器和数据
ASR 流程可以分解为三个部分:
- 一个对原始音频输入进行预处理的特征提取器。
- 执行序列到序列映射的模型
- 一个将模型输出后处理成文本格式的分词器
在 🤗 Transformers 中,Whisper 模型有一个相关的特征提取器和分词器,分别称为 WhisperFeatureExtractor 和 WhisperTokenizer。
我们将逐一详细介绍特征提取器和分词器!
加载 WhisperFeatureExtractor
语音由一个随时间变化的一维数组表示。数组在任何给定时间步的值是信号在该点的振幅。仅凭振幅信息,我们就可以重建音频的频谱并恢复所有的声学特征。
由于语音是连续的,它包含无限多个振幅值。这给期望有限数组的计算机设备带来了问题。因此,我们通过在固定的时间步长从信号中采样值来离散化我们的语音信号。我们采样音频的间隔称为采样率,通常以样本/秒或赫兹 (Hz) 为单位。以更高的采样率采样可以更好地逼近连续的语音信号,但每秒也需要存储更多的值。
将我们的音频输入采样率与模型期望的采样率相匹配至关重要,因为具有不同采样率的音频信号具有非常不同的分布。音频样本应始终以正确的采样率进行处理。否则可能导致意想不到的结果!例如,以 16kHz 的采样率采集音频样本,并以 8kHz 的采样率播放,会使音频听起来像是半速播放。同样,传递错误采样率的音频会使期望一种采样率而接收到另一种采样率的 ASR 模型失效。Whisper 特征提取器期望音频输入的采样率为 16kHz,因此我们需要将输入匹配到这个值。我们不希望无意中用慢动作语音训练 ASR 系统!
Whisper 特征提取器执行两个操作。它首先对一批音频样本进行填充/截断,使所有样本的输入长度都为 30 秒。短于 30 秒的样本通过在序列末尾附加零来填充到 30 秒(音频信号中的零对应于无信号或静音)。长于 30 秒的样本被截断为 30 秒。由于批次中的所有元素都在输入空间中被填充/截断到最大长度,因此在将音频输入转发给 Whisper 模型时,我们不需要注意力掩码。Whisper 在这方面是独特的——对于大多数音频模型,您需要提供一个注意力掩码,详细说明序列被填充的位置,从而在自注意力机制中应被忽略的位置。Whisper 被训练为在没有注意力掩码的情况下运行,并直接从语音信号中推断出忽略输入的位置。
Whisper 特征提取器执行的第二个操作是将填充后的音频数组转换为对数-梅尔频谱图。这些频谱图是信号频率的视觉表示,有点像傅里叶变换。图 2 显示了一个示例频谱图。 轴是梅尔通道,对应于特定的频率区间。 轴是时间。每个像素的颜色对应于给定时间该频率区间的对数强度。对数-梅尔频谱图是 Whisper 模型期望的输入形式。
梅尔通道(频率区间)在语音处理中是标准配置,其选择旨在近似人类听觉范围。对于 Whisper 微调,我们只需要知道频谱图是语音信号中频率的视觉表示。有关梅尔通道的更多详细信息,请参阅梅尔频率倒谱。

幸运的是,🤗 Transformers 的 Whisper 特征提取器只需一行代码就能完成填充和频谱图转换!让我们从预训练检查点加载特征提取器,为我们的音频数据做好准备。
from transformers import WhisperFeatureExtractor
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
加载 WhisperTokenizer
现在让我们看看如何加载 Whisper 分词器。Whisper 模型输出的文本词元表示预测文本在词汇表中的索引。分词器将一串文本词元映射到实际的文本字符串(例如 [1169, 3797, 3332] -> "the cat sat")。
传统上,当使用仅编码器模型进行 ASR 时,我们使用连接主义时间分类 (CTC)进行解码。在这里,我们需要为我们使用的每个数据集训练一个 CTC 分词器。使用编码器-解码器架构的优势之一是我们可以直接利用预训练模型的分词器。
Whisper 分词器是在 96 种预训练语言的转录文本上预训练的。因此,它有一个广泛的字节对,适用于几乎所有的多语言 ASR 应用。对于印地语,我们可以加载分词器并将其用于微调,无需任何进一步修改。我们只需指定目标语言和任务。这些参数会通知分词器在编码的标签序列的开头加上语言和任务词元。
from transformers import WhisperTokenizer
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")
提示: 通过将上述代码行中的任务设置为
"translate"
,并将语言设置为目标文本语言,本博客文章可以适用于语音翻译。这将在预处理数据集时,为语音翻译任务预置相关的任务和语言词元。
我们可以通过编码和解码 Common Voice 数据集的第一个样本来验证分词器是否能正确编码印地语字符。在对转录文本进行编码时,分词器会在序列的开头和结尾附加“特殊词元”,包括转录开始/结束词元、语言词元和任务词元(如上一步骤中的参数所指定)。在解码标签 ID 时,我们可以选择“跳过”这些特殊词元,从而以原始输入形式返回一个字符串。
input_str = common_voice["train"][0]["sentence"]
labels = tokenizer(input_str).input_ids
decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)
print(f"Input: {input_str}")
print(f"Decoded w/ special: {decoded_with_special}")
print(f"Decoded w/out special: {decoded_str}")
print(f"Are equal: {input_str == decoded_str}")
打印输出
Input: खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई
Decoded w/ special: <|startoftranscript|><|hi|><|transcribe|><|notimestamps|>खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई<|endoftext|>
Decoded w/out special: खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई
Are equal: True
合并创建 WhisperProcessor
为了简化特征提取器和分词器的使用,我们可以将两者封装到一个单独的 `WhisperProcessor` 类中。这个处理器对象继承自 `WhisperFeatureExtractor` 和 `WhisperProcessor`,可以根据需要用于音频输入和模型预测。这样做,我们在训练期间只需要跟踪两个对象:`processor` 和 `model`。
from transformers import WhisperProcessor
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")
准备数据
让我们打印 Common Voice 数据集的第一个例子,看看数据的形式。
print(common_voice["train"][0])
打印输出
{'audio': {'path': '/home/sanchit_huggingface_co/.cache/huggingface/datasets/downloads/extracted/607848c7e74a89a3b5225c0fa5ffb9470e39b7f11112db614962076a847f3abf/cv-corpus-11.0-2022-09-21/hi/clips/common_voice_hi_25998259.mp3',
'array': array([0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ..., 9.6724887e-07,
1.5334779e-06, 1.0415988e-06], dtype=float32),
'sampling_rate': 48000},
'sentence': 'खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई'}
我们可以看到,我们有一个一维的输入音频数组和相应的目标转录文本。我们已经重点讨论了采样率的重要性,以及我们需要将音频的采样率与 Whisper 模型的采样率(16kHz)相匹配。由于我们的输入音频采样率为 48kHz,我们需要在将其传递给 Whisper 特征提取器之前,将其下采样到 16kHz。
我们将使用数据集的 `cast_column` 方法将音频输入设置为正确的采样率。此操作不会就地更改音频,而是向 `datasets` 发出信号,在音频样本首次加载时动态地重新采样。
from datasets import Audio
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
重新加载 Common Voice 数据集中的第一个音频样本会将其重新采样到所需的采样率。
print(common_voice["train"][0])
打印输出
{'audio': {'path': '/home/sanchit_huggingface_co/.cache/huggingface/datasets/downloads/extracted/607848c7e74a89a3b5225c0fa5ffb9470e39b7f11112db614962076a847f3abf/cv-corpus-11.0-2022-09-21/hi/clips/common_voice_hi_25998259.mp3',
'array': array([ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
-3.4206650e-07, 3.2979898e-07, 1.0042874e-06], dtype=float32),
'sampling_rate': 16000},
'sentence': 'खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई'}
太棒了!我们可以看到采样率已经下采样到 16kHz。数组的值也不同了,因为我们现在大约每三个旧值才有一个振幅值。
现在我们可以编写一个函数来为模型准备数据。
- 我们通过调用 `batch["audio"]` 来加载和重采样音频数据。如上所述,🤗 Datasets 会动态执行任何必要的重采样操作。
- 我们使用特征提取器从我们的一维音频数组中计算对数-梅尔频谱图输入特征。
- 我们通过使用分词器将转录文本编码为标签 ID。
def prepare_dataset(batch):
# load and resample audio data from 48 to 16kHz
audio = batch["audio"]
# compute log-Mel input features from input audio array
batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
# encode target text to label ids
batch["labels"] = tokenizer(batch["sentence"]).input_ids
return batch
我们可以使用数据集的 `.map` 方法将数据准备函数应用于我们所有的训练样本。
common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=4)
好的!这样我们就为训练做好了充分的数据准备!让我们继续看看如何使用这些数据来微调 Whisper。
注意:目前 `datasets` 同时使用 `torchaudio` 和 `librosa` 进行音频加载和重采样。如果您希望实现自己的定制化数据加载/采样,可以使用 `"path"` 列获取音频文件路径并忽略 `"audio"` 列。
训练与评估
现在我们已经准备好了数据,可以深入研究训练流程了。 🤗 Trainer 将为我们完成大部分繁重的工作。我们只需要做的是:
加载预训练检查点:我们需要加载一个预训练检查点并为其正确配置训练。
定义数据整理器:数据整理器接收我们预处理过的数据,并准备好供模型使用的 PyTorch 张量。
评估指标:在评估期间,我们希望使用词错误率 (WER) 指标来评估模型。我们需要定义一个 `compute_metrics` 函数来处理这个计算。
定义训练参数:这些参数将由 🤗 Trainer 用于构建训练计划。
一旦我们微调了模型,我们将在测试数据上对其进行评估,以验证我们是否已正确地训练它来转录印地语语音。
加载预训练检查点
我们将从预训练的 Whisper `small` 检查点开始我们的微调运行。为此,我们将从 Hugging Face Hub 加载预训练的权重。同样,通过使用 🤗 Transformers,这非常简单!
from transformers import WhisperForConditionalGeneration
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
在推理时,Whisper 模型会自动检测源音频的语言,并预测该语言的词元 ID。在源音频语言已知的情况下,例如多语言微调,显式设置语言是有益的。这可以避免预测错误语言的情况,从而导致预测文本在生成过程中偏离真实语言。为此,我们将 langauge 和 task 参数设置为生成配置。我们还将任何 `forced_decoder_ids` 设置为 None,因为这是设置语言和任务参数的旧方法。
model.generation_config.language = "hindi"
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None
定义数据整理器
序列到序列语音模型的数据整理器是独特的,因为它独立处理 `input_features` 和 `labels`:`input_features` 必须由特征提取器处理,而 `labels` 必须由分词器处理。
`input_features` 已经被填充到 30 秒并转换为固定维度的对数-梅尔频谱图,所以我们只需将它们转换为批处理的 PyTorch 张量。我们使用特征提取器的 `.pad` 方法并设置 `return_tensors=pt` 来实现这一点。请注意,这里没有应用额外的填充,因为输入是固定维度的,`input_features` 只是被转换为 PyTorch 张量。
另一方面,`labels` 是未填充的。我们首先使用分词器的 `.pad` 方法将序列填充到批次中的最大长度。然后将填充词元替换为 `-100`,这样在计算损失时就不会考虑这些词元。然后我们从标签序列的开头剪掉转录开始词元,因为我们稍后会在训练中附加它。
我们可以利用我们之前定义的 `WhisperProcessor` 来执行特征提取器和分词器的操作。
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
processor: Any
decoder_start_token_id: int
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need different padding methods
# first treat the audio inputs by simply returning torch tensors
input_features = [{"input_features": feature["input_features"]} for feature in features]
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
# get the tokenized label sequences
label_features = [{"input_ids": feature["labels"]} for feature in features]
# pad the labels to max length
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
# if bos token is appended in previous tokenization step,
# cut bos token here as it's append later anyways
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
labels = labels[:, 1:]
batch["labels"] = labels
return batch
让我们初始化我们刚刚定义的数据整理器。
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
processor=processor,
decoder_start_token_id=model.config.decoder_start_token_id,
)
评估指标
接下来,我们定义将在评估集上使用的评估指标。我们将使用词错误率 (WER) 指标,这是评估 ASR 系统的“事实标准”指标。更多信息,请参阅 WER 文档。我们将从 🤗 Evaluate 加载 WER 指标。
import evaluate
metric = evaluate.load("wer")
然后我们只需定义一个函数,该函数接收我们的模型预测并返回 WER 指标。这个名为 `compute_metrics` 的函数首先在 `label_ids` 中将 `-100` 替换为 `pad_token_id`(撤销我们在数据整理器中为在损失中正确忽略填充词元而应用的步骤)。然后它将预测的 ID 和标签 ID 解码为字符串。最后,它计算预测和参考标签之间的 WER。
def compute_metrics(pred):
pred_ids = pred.predictions
label_ids = pred.label_ids
# replace -100 with the pad_token_id
label_ids[label_ids == -100] = tokenizer.pad_token_id
# we do not want to group tokens when computing the metrics
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
wer = 100 * metric.compute(predictions=pred_str, references=label_str)
return {"wer": wer}
定义训练参数
在最后一步,我们定义所有与训练相关的参数。下面解释了部分参数:
- `output_dir`:保存模型权重的本地目录。这也将是 Hugging Face Hub 上的仓库名称。
- `generation_max_length`:评估期间自回归生成的最大词元数。
- `save_steps`:训练期间,每 `save_steps` 个训练步骤,中间检查点将被保存并异步上传到 Hub。
- `eval_steps`:训练期间,每 `eval_steps` 个训练步骤,将对中间检查点进行评估。
- `report_to`:保存训练日志的位置。支持的平台有 `"azure_ml"`、`"comet_ml"`、`"mlflow"`、`"neptune"`、`"tensorboard"` 和 `"wandb"`。选择你喜欢的,或者保留为 `"tensorboard"` 以记录到 Hub。
有关其他训练参数的更多详细信息,请参阅 Seq2SeqTrainingArguments 文档。
from transformers import Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
output_dir="./whisper-small-hi", # change to a repo name of your choice
per_device_train_batch_size=16,
gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
learning_rate=1e-5,
warmup_steps=500,
max_steps=5000,
gradient_checkpointing=True,
fp16=True,
evaluation_strategy="steps",
per_device_eval_batch_size=8,
predict_with_generate=True,
generation_max_length=225,
save_steps=1000,
eval_steps=1000,
logging_steps=25,
report_to=["tensorboard"],
load_best_model_at_end=True,
metric_for_best_model="wer",
greater_is_better=False,
push_to_hub=True,
)
注意:如果不想将模型检查点上传到 Hub,请设置 `push_to_hub=False`。
我们可以将训练参数以及我们的模型、数据集、数据整理器和 `compute_metrics` 函数转发给 🤗 Trainer。
from transformers import Seq2SeqTrainer
trainer = Seq2SeqTrainer(
args=training_args,
model=model,
train_dataset=common_voice["train"],
eval_dataset=common_voice["test"],
data_collator=data_collator,
compute_metrics=compute_metrics,
tokenizer=processor.feature_extractor,
)
这样,我们就可以开始训练了!
训练
要启动训练,只需执行:
trainer.train()
训练将需要大约 5-10 小时,具体取决于您的 GPU 或分配给 Google Colab 的 GPU。根据您的 GPU,开始训练时可能会遇到 CUDA `"out-of-memory"` 错误。在这种情况下,您可以将 `per_device_train_batch_size` 递减 2 的倍数,并使用 `gradient_accumulation_steps` 来补偿。
打印输出
步骤 | 训练损失 | 轮次 | 验证损失 | WER |
---|---|---|---|---|
1000 | 0.1011 | 2.44 | 0.3075 | 34.63 |
2000 | 0.0264 | 4.89 | 0.3558 | 33.13 |
3000 | 0.0025 | 7.33 | 0.4214 | 32.59 |
4000 | 0.0006 | 9.78 | 0.4519 | 32.01 |
5000 | 0.0002 | 12.22 | 0.4679 | 32.10 |
在 4000 个训练步骤后,我们最好的 WER 是 32.0%。作为参考,预训练的 Whisper `small` 模型达到了 63.5% 的 WER,这意味着我们通过微调实现了 31.5% 的绝对改进。对于仅 8 小时的训练数据来说,这相当不错!
现在我们准备在 Hugging Face Hub 上分享我们微调过的模型。为了使其更易于访问,并带有适当的标签和 README 信息,我们可以在推送时设置适当的关键字参数(kwargs)。您可以根据您的数据集、语言和模型名称更改这些值。
kwargs = {
"dataset_tags": "mozilla-foundation/common_voice_11_0",
"dataset": "Common Voice 11.0", # a 'pretty' name for the training dataset
"dataset_args": "config: hi, split: test",
"language": "hi",
"model_name": "Whisper Small Hi - Sanchit Gandhi", # a 'pretty' name for your model
"finetuned_from": "openai/whisper-small",
"tasks": "automatic-speech-recognition",
}
现在可以将训练结果上传到 Hub。为此,请执行 `push_to_hub` 命令。
trainer.push_to_hub(**kwargs)
现在,您可以使用 Hub 上的链接与任何人分享此模型。他们也可以使用标识符 `"your-username/the-name-you-picked"` 加载它,例如:
from transformers import WhisperForConditionalGeneration, WhisperProcessor
model = WhisperForConditionalGeneration.from_pretrained("sanchit-gandhi/whisper-small-hi")
processor = WhisperProcessor.from_pretrained("sanchit-gandhi/whisper-small-hi")
虽然微调后的模型在 Common Voice 印地语测试数据上取得了令人满意的结果,但它绝不是最优的。本笔记本的目的是演示如何将预训练的 Whisper 检查点微调到任何多语言 ASR 数据集上。通过优化训练超参数,如学习率和丢弃率,并使用更大的预训练检查点(`medium` 或 `large-v3`),结果可能会得到改善。
构建演示
现在我们已经微调了我们的模型,我们可以构建一个演示来展示其 ASR 功能!我们将使用 🤗 Transformers `pipeline`,它将处理整个 ASR 流程,从预处理音频输入到解码模型预测。我们将使用 Gradio 构建我们的交互式演示。Gradio 可以说是构建机器学习演示最直接的方法;使用 Gradio,我们可以在几分钟内构建一个演示!
运行下面的示例将生成一个 Gradio 演示,我们可以通过计算机的麦克风录制语音,并将其输入到我们微调的 Whisper 模型中以转录相应的文本。
from transformers import pipeline
import gradio as gr
pipe = pipeline(model="sanchit-gandhi/whisper-small-hi") # change to "your-username/the-name-you-picked"
def transcribe(audio):
text = pipe(audio)["text"]
return text
iface = gr.Interface(
fn=transcribe,
inputs=gr.Audio(source="microphone", type="filepath"),
outputs="text",
title="Whisper Small Hindi",
description="Realtime demo for Hindi speech recognition using a fine-tuned Whisper small model.",
)
iface.launch()
结束语
在本博客中,我们介绍了使用 🤗 Datasets、Transformers 和 Hugging Face Hub 进行多语言 ASR 微调 Whisper 的分步指南。如果您想自己尝试微调,请参考 Google Colab。如果您对微调其他 Transformers 模型(包括英语和多语言 ASR)感兴趣,请务必查看 examples/pytorch/speech-recognition 中的示例脚本。