使用 🤗 Transformers 微调 W2V2-Bert 以支持低资源 ASR

发布于 2024 年 1 月 19 日
在 GitHub 上更新
Open In Colab

最新消息 (2024年1月)本博文深受《在多语言 ASR 上微调 XLS-R》和《为多语言 ASR 微调 MMS 适配器模型》的启发

引言

上个月,MetaAI 发布了 Wav2Vec2-BERT,作为其 Seamless Communication(一个AI翻译模型家族)的构建模块。

Wav2Vec2-BERT 是一系列改进的成果,其基础是原始模型:Wav2Vec2,这是一个用于自动语音识别 (ASR) 的预训练模型,由 Alexei Baevski、Michael Auli 和 Alex Conneau2020年9月发布。仅需 10 分钟的标注音频数据,Wav2Vec2 就可以通过微调在 LibriSpeech 数据集上达到 5% 的词错误率,首次展示了 ASR 的低资源迁移学习能力。

经过一系列多语言改进(XLSRXLS-RMMS),Wav2Vec2-BERT 是一个拥有 5.8 亿参数的多功能音频模型,它在覆盖超过 143 种语言450 万小时无标签音频数据上进行了预训练。相比之下,XLS-R 使用了近 50 万小时的 128 种语言的音频数据,而 MMS 检查点则在超过 1400 种语言50 多万小时音频上进行了预训练。将数据量提升至数百万小时,使得 Wav2Vec2-BERT 能够在与语音相关的任务中,无论何种语言,都能取得更具竞争力的结果。

为了将其用于 ASR,Wav2Vec2-BERT 可以使用连接主义时间分类 (Connectionist Temporal Classification, CTC) 进行微调。CTC 是一种用于训练序列到序列问题(如 ASR 和手写识别)神经网络的算法。我们强烈推荐阅读 Awni Hannun 撰写的优秀博文 Sequence Modeling with CTC (2017),以深入了解 CTC 算法。

本 notebook 的目的是为您提供训练 Wav2Vec2-BERT 模型——更具体地说是预训练检查点 facebook/w2v-bert-2.0——在 ASR 任务上所需的所有要素,全部使用开源工具和模型。它首先介绍了完整的预处理流程,然后对 W2V2-BERT 进行了少量微调。最后一部分汇集了 Hugging Face 专家关于扩展 CTC 训练的技巧。

出于演示目的,我们在 Common Voice 16.0 的低资源蒙古语 ASR 数据集上对模型进行微调,该数据集包含约 14 小时的已验证训练数据。

动机

Whisper 是一套 ASR 模型,被公认为 ASR 任务中表现最佳的模型。它在英语 ASR 方面提供了最先进的性能,同时也非常适合利用有限资源进行多语言微调。

然而,当涉及到像蒙古语这样的“资源贫乏”语言时,Whisper 的表现不佳,正如 Whisper 论文的 D.2.2 节所示——蒙古语或马拉雅拉姆语在每个 Whisper 检查点上的 WER 都超过了 100%。可用的检查点词汇量也有限,因此无法在字母表与该词汇表不重叠的语言上进行微调。

此外,Whisper 是一个序列到序列模型,它以自回归方式执行 ASR,这使其天生就“慢”。对于在训练数据集中不常见的语言,Whisper 的缓慢问题会更加严重。在这种情况下,Whisper 平均每个单词需要生成更多的 token,因此耗时更长。

面对有限的资源——无论是训练数据可用性还是推理限制——需要更“节俭”的模型。在这种情况下,Wav2Vec2-BERT 恰好满足了这一需求。

Wav2Vec2-BERT 通过单次前向传播预测 ASR,使其比 Whisper 快得多。正如本 notebook 将展示的,它需要少量数据即可达到有竞争力的性能易于适应任何字母表,并且资源效率更高

事实上,在经过类似的微调后,它在蒙古语 ASR 上的 WER 性能与 Whisper-large-v3 相当,同时速度快 10 到 30 倍以上,资源效率高 2.5 倍

注意:基准测试是在 Google Colab 上的 16GB V100 GPU 上进行的,在蒙古语 CV16 测试集上使用的批大小从 1 到 8 不等。

Notebook 设置

开始之前,我们先安装 datasetstransformers。此外,我们还需要 accelerate 用于训练,torchaudio 用于加载音频文件,以及 jiwer 用于使用词错误率 (WER) 指标来评估我们微调后的模型。

%%capture
!pip install datasets
!pip install --upgrade transformers
!pip install torchaudio
!pip install jiwer
!pip install accelerate -U

我们强烈建议在训练过程中将您的训练检查点直接上传到 🤗 Hub🤗 Hub 提供:

  • 集成的版本控制:你可以确保在训练过程中不会丢失任何模型检查点。
  • Tensorboard 日志:在训练过程中跟踪重要指标。
  • 模型卡片:记录模型的功能及其预期用途。
  • 社区:一种与社区分享和协作的简便方式!

为此,您需要从 Hugging Face 网站存储您的身份验证令牌(如果还没有,请在此处注册!)。在下方提示时输入您的 Hub 身份验证令牌即可。在此处查找您的 Hub 身份验证令牌

from huggingface_hub import notebook_login

notebook_login()

准备数据、分词器、特征提取器

ASR 模型将语音转写为文本,这意味着我们既需要一个将语音信号处理成模型输入格式(例如特征向量)的特征提取器,也需要一个将模型输出格式处理成文本的分词器。

在 🤗 Transformers 中,Wav2Vec2-BERT 模型因此配备了一个分词器,名为 Wav2Vec2CTCTokenizer,和一个特征提取器,名为 SeamlessM4TFeatureExtractor。该特征提取器与 第一版第二版的 Seamless-M4T 共享,因为它们都以相同的方式处理音频。

让我们从创建分词器开始,用它来将预测的输出类别解码为输出转写文本。

创建 Wav2Vec2CTCTokenizer

请记住,在 CTC 上微调的 Wav2Vec2-like 模型通过单次前向传播来转写音频文件,首先将音频输入处理成一系列处理过的上下文表示,然后使用最终的词汇表输出层将每个上下文表示分类为一个代表转写文本的字符。

该层的输出大小对应于词汇表中的 token 数量,因此只取决于用于微调的带标签数据集。所以第一步,我们将查看选定的 Common Voice 数据集,并根据转写文本定义一个词汇表。

对于本 notebook,我们将使用 Common Voice 16.0 数据集的蒙古语部分。蒙古语对应的语言代码是 "mn"

现在我们可以使用 🤗 Datasets 的简单 API 来下载数据。数据集名称是 "mozilla-foundation/common_voice_16_0",配置名称对应于语言代码,在我们的例子中是 "mn"

注意:在能够下载数据集之前,您必须先登录您的 Hugging Face 账户,访问数据集仓库页面,然后点击“同意并访问仓库”来获取权限。

Common Voice 有许多不同的数据划分,包括 invalidated,指的是那些被评为不够“干净”而无法使用的数据。在本 notebook 中,我们只使用 "train""validation""test" 这几个划分。

因为蒙古语数据集非常小,我们将把验证集和训练集合并为一个训练集,并只使用测试集进行验证。

from datasets import load_dataset, load_metric, Audio

common_voice_train = load_dataset("mozilla-foundation/common_voice_16_0", "mn", split="train+validation", use_auth_token=True)
common_voice_test = load_dataset("mozilla-foundation/common_voice_16_0", "mn", split="test", use_auth_token=True)

许多 ASR 数据集仅为每个音频数组 'audio' 和文件 'path' 提供目标文本 'sentence'。Common Voice 实际上提供了关于每个音频文件的更多信息,例如 'accent' 等。为了使 notebook 尽可能通用,我们只考虑转写文本进行微调。

common_voice_train = common_voice_train.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])

让我们写一个简短的函数来显示数据集的一些随机样本,并运行几次以感受一下转写文本的特点。

from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

show_random_elements(common_voice_train.remove_columns(["path", "audio"]), num_examples=10)

好的!转写文本看起来相当干净。翻译了这些转写句子后,似乎语言更像是书面文本,而不是嘈杂的对话。考虑到 Common Voice 是一个众包的朗读语音语料库,这很合理。

我们可以看到转写文本包含一些特殊字符,比如 ,.?!;:。在没有语言模型的情况下,将语音块分类为这些特殊字符要困难得多,因为它们并不真正对应一个特征性的声音单元。例如,字母 "s" 有一个或多或少清晰的发音,而特殊字符 "." 则没有。此外,为了理解语音信号的含义,通常没有必要在转写中包含特殊字符。

让我们简单地移除所有对词义没有贡献且无法真正用声音表示的字符,并对文本进行规范化。

import re
chars_to_remove_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'\»\«]'

def remove_special_characters(batch):
    # remove special characters
    batch["sentence"] = re.sub(chars_to_remove_regex, '', batch["sentence"]).lower()

    return batch

common_voice_train = common_voice_train.map(remove_special_characters)
common_voice_test = common_voice_test.map(remove_special_characters)

让我们再看一下处理后的文本标签。

show_random_elements(common_voice_train.remove_columns(["path","audio"]))
Хойч үе юуны төлөө тэмцэлдэхийг би мэдэхгүй.	
Тэр өвдгөн дээрээ толгойгоо тавиад сулхан гиншинэ.	
Эхнэргүй ганц бие хүн гэсэн санагдана.	
Дамиран хотод төрж өссөн хээнцэр залуусын нэг билээ.	
Мөн судлаачид шинжлэх ухааны үндэстэй тайлбар хайдаг.	
Судалгааны ажил нь бүтэлгүй болсонд л гутарч маргааш илүү ажиллах тухай бодсон бололтой.	
Ийм зөрчлөөс гэтлэх гарц "Оноосон нэрийн сан"-г үүсгэснээр шийдвэрлэгдэнэ.	
Үүлтэй тэнгэрийн доогуур үзүүртэй моддын дээгүүр дүүлэн нисэх сэн.	
Та нар ямар юмаа ингэж булаацалдаа вэ?	
Тэд амьд хэлтрээ болов уу яагаа бол гэхээс одоо ч дотор арзганан бачуурдаг юм.	

在 CTC 中,通常将语音块分类为字母,所以我们在这里也这样做。让我们提取训练和测试数据中所有不同的字母,并从这个字母集合构建我们的词汇表。

我们编写一个映射函数,将所有转写文本连接成一个长转写文本,然后将字符串转换为一个字符集合。将参数 batched=True 传递给 map(...) 函数非常重要,这样映射函数就可以一次性访问所有转写文本。

def extract_all_chars(batch):
  all_text = " ".join(batch["sentence"])
  vocab = list(set(all_text))
  return {"vocab": [vocab], "all_text": [all_text]}

vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)

现在,我们创建训练集和测试集中所有不同字母的并集,并将结果列表转换为一个带索引的字典。

vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))

vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
vocab_dict
{' ': 0,
 'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'g': 6,
 'h': 7,
 'i': 8,
 'l': 9,
 'n': 10,
 'o': 11,
 'r': 12,
 't': 13,
 'x': 14,
 'а': 15,
 'б': 16,
 'в': 17,
 'г': 18,
 'д': 19,
 'е': 20,
 'ж': 21,
 'з': 22,
 'и': 23,
 'й': 24,
 'к': 25,
 'л': 26,
 'м': 27,
 'н': 28,
 'о': 29,
 'п': 30,
 'р': 31,
 'с': 32,
 'т': 33,
 'у': 34,
 'ф': 35,
 'х': 36,
 'ц': 37,
 'ч': 38,
 'ш': 39,
 'ъ': 40,
 'ы': 41,
 'ь': 42,
 'э': 43,
 'ю': 44,
 'я': 45,
 'ё': 46,
 'ү': 47,
 'ө': 48}

清理数据集是一个需要谨慎进行的反复过程。

查看训练集和测试集中的单个字母,我们发现既有拉丁字母,也有蒙古语西里尔字母。在与一位母语为目标语言的人士(感谢 Mishig 的审阅)讨论后,我们将移除拉丁字母,原因有二:

  1. CTC 算法受益于较小的词汇量,因此建议移除多余的字符。
  2. 在这个例子中,我们完全专注于蒙古语字母表。
def remove_latin_characters(batch):
    batch["sentence"] = re.sub(r'[a-z]+', '', batch["sentence"])
    return batch

# remove latin characters
common_voice_train = common_voice_train.map(remove_latin_characters)
common_voice_test = common_voice_test.map(remove_latin_characters)

# extract unique characters again
vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)
vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))

vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
vocab_dict
{' ': 0,
 'а': 1,
 'б': 2,
 'в': 3,
 'г': 4,
 'д': 5,
 'е': 6,
 'ж': 7,
 'з': 8,
 'и': 9,
 'й': 10,
 'к': 11,
 'л': 12,
 'м': 13,
 'н': 14,
 'о': 15,
 'п': 16,
 'р': 17,
 'с': 18,
 'т': 19,
 'у': 20,
 'ф': 21,
 'х': 22,
 'ц': 23,
 'ч': 24,
 'ш': 25,
 'ъ': 26,
 'ы': 27,
 'ь': 28,
 'э': 29,
 'ю': 30,
 'я': 31,
 'ё': 32,
 'ү': 33,
 'ө': 34}

太好了,我们看到蒙古语字母表中的所有字母都出现在数据集中(这并不意外),我们还提取了特殊字符 " "。注意,我们没有排除这个特殊字符,因为:模型必须学会预测单词何时结束,否则模型预测将永远是一串字符,无法将单词彼此分开。

应该始终记住,预处理是训练模型前非常重要的一步。例如,我们不希望模型仅仅因为我们忘记了数据规范化而去区分 aAaA 之间的区别根本不取决于字母的“发音”,而更多地取决于语法规则——例如,在句子开头使用大写字母。因此,消除大小写字母之间的差异是明智的,这样模型就能更容易地学习转写语音。您可以在音频 Transformers 课程中阅读更多关于预处理对 ASR 任务影响的内容。

为了更清楚地表明 " " 有自己的 token 类别,我们给它一个更显眼的字符 |。此外,我们还添加了一个“未知”token,以便模型以后可以处理 Common Voice 训练集中未遇到的字符。

vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

最后,我们还添加了一个与 CTC 的“空白 token”相对应的填充 token。“空白 token”是 CTC 算法的核心组成部分。更多信息,请参阅这篇博文的“对齐”部分。

vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)
37

太好了,现在我们的词汇表已经完成,包含 37 个 token,这意味着我们将在预训练的 Wav2Vec2-BERT 检查点之上添加的线性层将具有 37 的输出维度。

现在让我们将词汇表保存为 json 文件。

import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

最后一步,我们使用这个 json 文件将词汇表加载到 Wav2Vec2CTCTokenizer 类的一个实例中。

from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

如果想要将刚刚创建的分词器与本 notebook 中微调的模型一起重用,强烈建议将 tokenizer 上传到 🤗 Hub。让我们将要上传文件的仓库命名为 "w2v-bert-2.0-mongolian-colab-CV16.0"

repo_name = "w2v-bert-2.0-mongolian-colab-CV16.0"

然后将分词器上传到 🤗 Hub

tokenizer.push_to_hub(repo_name)

太好了,您可以在 https://huggingface.co/<your-username>/w2v-bert-2.0-mongolian-colab-CV16.0 下看到刚刚创建的仓库。

创建 SeamlessM4TFeatureExtractor

SeamlessM4TFeatureExtractor 的作用是将原始音频输入准备成模型能够“理解”的格式。因此,它将一维振幅值序列(即原始音频输入)映射到一个二维的对数梅尔频谱图矩阵。后者将信号的频率信息编码为时间的函数。请参阅音频 Transformers 课程的这一节以了解更多关于频谱图及其重要性的信息。

与分词器不同,特征提取器不需要从数据中“学习”,因此我们可以直接从初始模型检查点加载它。

from transformers import SeamlessM4TFeatureExtractor

feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

太好了,Wav2Vec2-BERT 的特征提取流程就此完全定义好了!

为了提高用户友好性,特征提取器和分词器被封装在一个名为 Wav2Vec2BertProcessor 的类中,这样一来,我们只需要一个 model 和一个 processor 对象即可。

from transformers import Wav2Vec2BertProcessor

processor = Wav2Vec2BertProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
processor.push_to_hub(repo_name)

接下来,我们可以准备数据集了。

预处理数据

到目前为止,我们还没有查看语音信号的实际数值,只看了转写文本。除了 sentence,我们的数据集还包括另外两个列名 pathaudiopath 指明了音频文件的绝对路径。我们来看看。

common_voice_train[0]["path"]
/root/.cache/huggingface/datasets/downloads/extracted/276aa682ce2b6a24934bc401b1f30e004c3fb178dd41d6295b273329f592844a/mn_train_0/common_voice_mn_18578097.mp3

Wav2Vec2-BERT 期望输入格式为 16 kHz 的一维数组。这意味着音频文件必须被加载和重采样。

幸运的是,datasets 通过调用另一列 audio 自动完成这个过程。让我们来试试。

common_voice_train[0]["audio"]
{'path': '/root/.cache/huggingface/datasets/downloads/extracted/276aa682ce2b6a24934bc401b1f30e004c3fb178dd41d6295b273329f592844a/mn_train_0/common_voice_mn_18578097.mp3',
 'array': array([ 0.00000000e+00, -1.64773251e-14,  1.81765166e-13, ...,
        -3.23167333e-05,  2.20304846e-05,  3.26883201e-05]),
 'sampling_rate': 48000}

太好了,我们可以看到音频文件已自动加载。这要归功于 datasets == 4.13.3 中引入的新的"Audio"特性,它在调用时会即时加载和重采样音频文件。

在上面的例子中,我们可以看到音频数据以 48kHz 的采样率加载,而 Wav2Vec2-BERT 是在 16kHz 的采样率下进行预训练的。采样率起着重要作用,因为它定义了每秒测量多少语音信号的数据点。因此,以更高的采样率进行采样可以更好地逼近真实的语音信号,但每秒也需要更多的数值。

一个预训练的检查点期望其输入数据与它训练时所用的数据大致来自相同的分布。以两种不同速率采样的相同语音信号具有非常不同的分布,例如,将采样率加倍会导致数据点长度增加一倍。因此,在微调一个 ASR 模型的预训练检查点之前,验证用于预训练模型的数据采样率与用于微调模型的数据集采样率是否匹配至关重要。

幸运的是,我们可以通过使用 cast_column 将音频特征设置为正确的采样率。

common_voice_train = common_voice_train.cast_column("audio", Audio(sampling_rate=16_000))
common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))

我们再来看看 "audio"

common_voice_train[0]["audio"]
{'path': '/root/.cache/huggingface/datasets/downloads/extracted/276aa682ce2b6a24934bc401b1f30e004c3fb178dd41d6295b273329f592844a/mn_train_0/common_voice_mn_18578097.mp3',
 'array': array([ 9.09494702e-12, -2.27373675e-13,  5.45696821e-12, ...,
        -5.22854862e-06, -1.21556368e-05, -9.76262163e-06]),
 'sampling_rate': 16000}

这似乎起作用了!让我们听几个音频文件,以更好地了解数据集并验证音频是否已正确加载。

import IPython.display as ipd
import numpy as np
import random

rand_int = random.randint(0, len(common_voice_train)-1)

print(common_voice_train[rand_int]["sentence"])
ipd.Audio(data=common_voice_train[rand_int]["audio"]["array"], autoplay=True, rate=16000)

看起来数据现在已经正确加载和重采样了。

可以听出,说话者在变化,他们的语速、口音和背景环境等也随之改变。不过,总的来说,录音听起来足够清晰,这对于一个众包的朗读语音语料库来说是意料之中的。

让我们做最后一次检查,确认数据准备是否正确,通过打印语音输入的形状、其转写文本以及相应的采样率。

rand_int = random.randint(0, len(common_voice_train)-1)

print("Target text:", common_voice_train[rand_int]["sentence"])
print("Input array shape:", common_voice_train[rand_int]["audio"]["array"].shape)
print("Sampling rate:", common_voice_train[rand_int]["audio"]["sampling_rate"])
Target text: энэ бол тэдний амжилтын бодит нууц
Input array shape: (74496,)
Sampling rate: 16000

好的!一切看起来都没问题——数据是一维数组,采样率总是 16kHz,目标文本也已规范化。

最后,我们可以利用 Wav2Vec2BertProcessor 将数据处理成 Wav2Vec2BertForCTC 训练时期望的格式。为此,让我们使用 Dataset 的 map(...) 函数。

首先,我们加载并重采样音频数据,只需调用 batch["audio"]。其次,我们从加载的音频文件中提取 input_features。在我们的例子中,Wav2Vec2BertProcessor 创建了一个比原始波形更复杂的表示,即对数梅尔特征提取。第三,我们将转写文本编码为标签 ID。

def prepare_dataset(batch):
    audio = batch["audio"]
    batch["input_features"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    batch["input_length"] = len(batch["input_features"])

    batch["labels"] = processor(text=batch["sentence"]).input_ids
    return batch

让我们将数据准备函数应用到所有样本上。

common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names)
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names)

注意**:`datasets` 自动处理音频加载和重采样。如果您希望实现自己定制的数据加载/采样,可以随时使用 `“path”` 列,而忽略 `“audio”` 列。

太棒了,现在我们准备好开始训练了!

训练

数据已经处理完毕,我们可以开始设置训练流程了。我们将使用 🤗 Transformer 的 Trainer 类,为此我们主要需要做以下几件事:

  • 定义一个数据整理器。与大多数 NLP 模型不同,Wav2Vec2-BERT 的输入长度远大于输出长度。鉴于输入尺寸较大,动态填充训练批次效率更高,这意味着所有训练样本只应填充到其批次中最长样本的长度,而不是整个数据集中最长样本的长度。因此,微调 Wav2Vec2-BERT 需要一个特殊的填充数据整理器,我们将在下面定义它。

  • 评估指标。在训练期间,模型应以词错误率进行评估。我们应该相应地定义一个 compute_metrics 函数。

  • 加载预训练的检查点。我们需要加载一个预训练的检查点,并为其进行正确的训练配置。

  • 定义训练配置。

在微调模型后,我们将在测试数据上对其进行正确评估,并验证它确实学会了正确转写语音。

设置 Trainer

让我们从定义数据整理器开始。数据整理器的代码是从这个例子中复制的。

不深入太多细节,与常见的数据整理器相比,这个数据整理器对 input_featureslabels 的处理方式不同,因此对它们应用了不同的填充函数。这是必要的,因为在语音中,输入和输出是不同模态的,这意味着它们不应该用相同的填充函数来处理。与常见的数据整理器类似,标签中的填充 token 用 -100 替换,这样这些 token 在计算损失时就会被考虑在内。

import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:

    processor: Wav2Vec2BertProcessor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )

        labels_batch = self.processor.pad(
            labels=label_features,
            padding=self.padding,
            return_tensors="pt",
        )
        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

接下来,定义评估指标。如前所述,ASR 中最主要的指标是词错误率 (WER),因此我们在这个 notebook 中也使用它。

wer_metric = load_metric("wer")

模型将返回一个 logit 向量序列: y1,,ym \mathbf{y}_1, \ldots, \mathbf{y}_m ,其中 y1=fθ(x1,,xn)[0] \mathbf{y}_1 = f_{\theta}(x_1, \ldots, x_n)[0] n>>m n >> m

一个 logit 向量 y1 \mathbf{y}_1 包含了我们之前定义的词汇表中每个单词的对数几率,因此 len(yi)= \text{len}(\mathbf{y}_i) = `config.vocab_size`。我们对模型最可能的预测感兴趣,因此取 logits 的 `argmax(...)`。此外,我们通过将 `-100` 替换为 `pad_token_id` 并解码这些 ID,将编码后的标签转换回原始字符串,同时确保连续的 token 在 CTC 风格下**不**被分组为同一个 token1 {}^1

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

现在,我们可以加载主要的预训练检查点了。必须将分词器的 `pad_token_id` 定义为模型的 `pad_token_id`,或者在 `Wav2Vec2BertForCTC` 的情况下,也定义为 CTC 的*空白 token*2 {}^2 。为了节省 GPU 内存,我们启用了 PyTorch 的梯度检查点,并将损失缩减设置为“mean”。

由于我们只训练一小部分权重,模型不容易过拟合。因此,我们确保禁用所有 dropout 层。

注意:当使用此 notebook 在 Common Voice 的另一种语言上训练 Wav2Vec2-BERT 时,这些超参数设置可能效果不佳。请根据您的用例随意调整。

from transformers import Wav2Vec2BertForCTC

model = Wav2Vec2BertForCTC.from_pretrained(
    "facebook/w2v-bert-2.0",
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    mask_time_prob=0.0,
    layerdrop=0.0,
    ctc_loss_reduction="mean",
    add_adapter=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)

最后一步,我们定义所有与训练相关的参数。对其中一些参数进行更多解释:

  • group_by_length 通过将相似输入长度的训练样本分组到一个批次中,使训练更加高效。这可以通过大量减少通过模型的无用填充 token 的总数来显著加快训练时间。
  • learning_rate 是通过启发式调整得到的,直到微调变得稳定。请注意,这些参数很大程度上取决于 Common Voice 数据集,对于其他语音数据集可能不是最优的。

关于其他参数的更多解释,可以查看文档

在训练期间,每 600 个训练步骤就会有一个检查点异步上传到 Hub。这让您即使在模型仍在训练时,也能试用演示小部件。

注意:如果不想将模型检查点上传到 Hub,只需设置 push_to_hub=False

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir=repo_name,
  group_by_length=True,
  per_device_train_batch_size=16,
  gradient_accumulation_steps=2,
  evaluation_strategy="steps",
  num_train_epochs=10,
  gradient_checkpointing=True,
  fp16=True,
  save_steps=600,
  eval_steps=300,
  logging_steps=300,
  learning_rate=5e-5,
  warmup_steps=500,
  save_total_limit=2,
  push_to_hub=True,
)

现在,所有实例都可以传递给 Trainer,我们准备开始训练了!

from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor,
)

1 {}^1 为了让模型能够独立于说话者的语速,在 CTC 中,连续相同的 token 会被简单地归为一个 token。然而,在解码时,编码后的标签不应该被分组,因为它们不对应于模型的预测 token,这就是为什么必须传递 `group_tokens=False` 参数的原因。如果我们不传递这个参数,像 `“hello”` 这样的词就会被错误地编码和解码为 `“helo”`。 2 {}^2 空白 token 允许模型预测像 `“hello”` 这样的词,通过强制它在两个 l 之间插入空白 token。我们的模型对 `“hello”` 的一个符合 CTC 规范的预测会是 `[PAD] [PAD] "h" "e" "e" "l" "l" [PAD] "l" "o" "o" [PAD]`。

训练

训练将花费数小时,具体取决于分配给此 notebook 的 GPU。虽然训练后的模型在 Common Voice 的蒙古语测试数据上取得了尚可的结果,但它绝不是一个最优微调的模型。本 notebook 的目的仅仅是演示如何在一个 ASR 数据集上微调 Wav2Vec2-BERT。

trainer.train()
步骤 训练损失 验证损失 词错误率 (Wer)
300 1.712700 0.647740 0.517892
600 0.349300 0.615849 0.442027
900 0.180500 0.525088 0.367305
1200 0.075400 0.528768 0.324016

训练损失和验证 WER 都很好地下降了。相比之下,使用 whisper-large-v3(公认的 OpenAI 最先进的 ASR 模型)进行相同的训练,最终的 WER 为 33.3%。您可以在这里找到最终的 Whisper 检查点。这表明 Wav2Vec2-Bert 在低资源语言上可以达到接近或等同于最先进水平的性能

你现在可以把训练结果上传到 🤗 Hub,只需执行这条指令即可。

trainer.push_to_hub()

你现在可以和所有的朋友、家人、心爱的宠物分享这个模型:他们都可以用“your-username/the-name-you-picked”这个标识符来加载它,例如:

from transformers import AutoModelForCTC, Wav2Vec2BertProcessor

model = AutoModelForCTC.from_pretrained("ylacombe/w2v-bert-2.0-mongolian-colab-CV16.0")
processor = Wav2Vec2BertProcessor.from_pretrained("ylacombe/w2v-bert-2.0-mongolian-colab-CV16.0")

关于如何微调 Wav2Vec2-BERT 的更多示例,请参阅官方语音识别示例

评估

作为最后一次检查,我们加载模型并验证它确实学会了转写蒙古语语音。

我们先加载预训练的检查点。

model = Wav2Vec2BertForCTC.from_pretrained(repo_name).to("cuda")
processor = Wav2Vec2BertProcessor.from_pretrained(repo_name)

让我们处理音频,进行一次前向传播并预测 ID。

sample = common_voice_test[0]
input_features = torch.tensor(sample["input_features"]).to("cuda").unsqueeze(0)

with torch.no_grad():
    logits = model(input_features).logits

pred_ids = torch.argmax(logits, dim=-1)[0]

最后,我们可以从预测的 token 解码出示例,并与参考转写文本进行比较。

print(processor.decode(pred_ids))
print(processor.decode(sample["labels"]).lower())
эрчүүдийн ганцаардлыг эмэхтэйчүүд ойлгох нь ховор юм
эрчүдийн ганцардлыг эмэгтэйчүд ойлгох нь ховор юм

好的!从我们的预测中绝对可以辨认出转写内容,但还不够完美。将模型训练更长时间,在数据预处理上花费更多时间,尤其是使用语言模型进行解码,肯定会提高模型的整体性能。

然而,对于一个低资源语言的演示模型来说,这个结果已经相当可以接受了 🤗。

扩展训练规模

我们在这篇博文中展示了 Meta 的 w2v-bert-2.0 微调如何在低资源语言上取得接近最先进水平的性能。

为了更进一步,我整理了一套由我在 Hugging Face 的同事们提供的关于如何扩展此模型训练的技巧和要点。这些技巧是在我向他们展示这篇博文的训练运行以及其他训练尝试(这里这里)时浮出水面的。

非常感谢 PatrickSanchitPablo 提供的宝贵专业知识和帮助 🤗

请注意,Common Voice 最新版本 (CV16) 为许多语言提供了更多小时的数据,从而为在许多低资源语言中构建更高效的模型提供了肥沃的土壤。

数据集相关技巧

CTC ASR 通常使用小写、无标点的转写文本。这简化了 CTC 任务,因为模型被视为“纯声学”模型,意味着它的预测主要基于音频的语音声音,而不是口语句子的任何语言模型上下文。

频率极低的字符会通过错误的目标导致损失激增,从而显著影响学习过程中的损失。默认情况下,本博文创建的 CTC 分词器会将它们添加到词汇表中,即使它们的频率与更常见的字符相比可以忽略不计。我们可以将这些字符视为数据集标注中的“错误”,以便将它们从词汇表中移除,并在训练期间简单地分类为 `"[UNK]"`。

因此,绝对有必要重新检查分词器词汇表,并移除所有低频字符,就像我们创建分词器时移除拉丁字符一样。

请注意,Common Voice 数据集特别容易出现这类“错误”字符,例如来自其他语言的字符(阪)。

训练相关技巧

每个 CTC token 看到的平均时长: 通过实验,我们发现每个 CTC token 看到的理想时长比例是 10 到 35 毫秒。换句话说,为了能够正确学习和预测,CTC token 需要看到的声学信息时长既不能太低也不能太高。实际上,它应该大致对应于我们人类发一个音素所需时间的一小部分。

我的一次训练运行的损失曲线最初如预期般平稳下降,但在某个点开始爆炸。我意识到我一直使用的是一个没有架构改动的基本检查点,每个 CTC token 看到的信号时长为 30 到 60 毫秒。添加一个卷积适配器层来对编码器隐藏状态沿时间维度进行子采样,足以将信号块采样减少到期望的时长,并防止这种损失曲线的出现。

训练不足:我的同事们在查看我的训练运行时很快注意到模型严重训练不足,这一点可以从损失曲线上看出来,它看起来像是在陡峭下降的中间被停止了。这也指出了其他问题,特别是损失曲线不够平滑,这是超参数设置不当的迹象。

这里有几种解决我们案例中训练不足的方法:

  • 预热率可能太高,导致学习率下降过快。一个解决方法是保持预热率在 5% 到 15% 之间,并增加训练轮数。预热步骤对于逐渐将新的语言模型头权重与预训练模型对齐至关重要。
  • 损失曲线不平滑的问题可以通过调整 AdamWβ2 \beta_2 来解决,该参数通常可以默认设置为 0.95 到 0.98。

相关文章和附加链接列于此处:

社区

这简直是一篇低水平、复制粘贴、代码写得差的帖子。显然作者在写这些代码时对任何事情都毫无理解。

·

你好 @nicccobb
我看到您觉得这篇帖子和代码有所欠缺。
您能具体指出哪些部分的代码看起来不正确吗?我愿意在需要的地方进行修改并提供更多深度。另外,我已经发现一个明显的错误:我计算每个 CTC token 覆盖的时间跨度(毫秒/token)是错的。
我很乐意听取您认为应该修正或扩充的任何其他观点。建设性的反馈总是受欢迎的——我乐于纠正和学习。

编辑:好文章!谢谢!!

我有一个同样的项目,但是是关于阿拉伯语转写的。我可以遵循同样的步骤吗?

注册登录 以发表评论