为多语言 ASR 微调 MMS 适配器模型

发布于 2023 年 6 月 19 日
在 GitHub 上更新
Open In Colab

新(2023 年 6 月)这篇博文深受《在多语言 ASR 上微调 XLS-R》的启发,可视为其改进版。

Wav2Vec2 是一个用于自动语音识别 (ASR) 的预训练模型,由 Alexei Baevski、Michael Auli 和 Alex Conneau2020 年 9 月发布。Wav2Vec2 在最流行的英文 ASR 数据集 LibriSpeech 上展示出强大性能后不久,Facebook AI 推出了两个 Wav2Vec2 的多语言版本,分别称为 XLSRXLM-R,能够识别多达 128 种语言的语音。XLSR 代表跨语言语音表示,指的是模型学习在多种语言中有用的语音表示的能力。

Meta AI 最新发布的 大规模多语言语音 (MMS),由 Vineel Pratap、Andros Tjandra、Bowen Shi 等人完成,将多语言语音表示提升到新水平。通过发布的各种语言识别、语音识别和文本到语音检查点,可以识别、转录和生成超过 1,100 种口语。

在这篇博文中,我们将展示 MMS 的适配器训练如何在仅 10-20 分钟的微调后,达到惊人的低词错误率。

对于低资源语言,我们强烈建议使用 MMS 的适配器训练,而不是像《在多语言 ASR 上微调 XLS-R》中所做的那样微调整个模型。

在我们的实验中,MMS 的适配器训练在内存效率、鲁棒性方面都更好,并且在低资源语言上产生更好的性能。然而,对于中高资源语言,微调整个检查点而不是使用适配器层可能仍然更有利。

wav2vec2_structure

保护世界语言多样性

根据 https://www.ethnologue.com/ 的数据,约 3000 种,即 40% 的“活语言”,由于母语使用者越来越少而面临濒危。这种趋势在全球化日益加剧的世界中只会继续。

MMS 能够转录许多濒危语言,例如 AriKaivi。未来,MMS 可以通过帮助剩余的使用者创建书面记录并用母语交流,在保持语言活力方面发挥至关重要的作用。

为了适应 1000 多个不同的词汇,MMS 使用了适配器——一种仅训练模型权重一小部分的训练方法。

适配器层就像语言桥梁,使模型能够利用一种语言的知识来破译另一种语言。

MMS 微调

MMS 无监督检查点在超过 50 万小时的音频上进行了预训练,涵盖 1400 多种语言,参数量从 3 亿到 10 亿不等。

您可以在 🤗 Hub 上找到 3 亿参数 (300M) 和 10 亿参数 (1B) 模型大小的仅预训练检查点

注意:如果您想微调基础模型,可以按照“在多语言 ASR 上微调 XLS-R”中所示的完全相同的方式进行操作。

BERT 的掩码语言建模目标类似,MMS 通过在自监督预训练期间将特征向量随机掩码然后将其传递给 Transformer 网络来学习上下文语音表示。

对于 ASR,预训练的 MMS-1B 检查点在 1000 多种语言上通过联合词汇输出层进一步进行了监督微调。作为最后一步,联合词汇输出层被丢弃,并保留了特定于语言的适配器层。每个适配器层仅包含约 2.5M 个权重,由每个注意力块的小型线性投影层以及一个特定于语言的词汇输出层组成。

已发布三个用于语音识别 (ASR) 的MMS微调检查点。它们分别包含 102、1107 和 1162 个适配器权重(每种语言一个)

您可以看到基础模型(像往常一样)保存为 model.safetensors 文件,但此外,这些存储库中还存储了许多适配器权重,例如法国的 adapter.fra.safetensors

Hugging Face 文档详细解释了如何将这些检查点用于推理,因此本篇博文将重点介绍如何基于任何已发布的 ASR 检查点高效训练高性能适配器模型。

训练自适应权重

在机器学习中,适配器是一种用于微调预训练模型的方法,同时保持原始模型参数不变。它们通过在模型的现有层之间插入小的可训练模块(称为适配器层)来实现这一点,然后这些模块使模型适应特定任务,而无需进行大量再训练。

适配器在语音识别,尤其是说话人识别方面有着悠久的历史。在说话人识别中,适配器已有效地用于调整现有模型以识别个体说话人的特殊习惯,正如Gales 和 Woodland (1996)以及Miao 等人 (2014)的工作所强调的那样。与训练完整模型相比,这种方法不仅大大减少了计算需求,而且还允许更好、更灵活的针对说话人的调整。

MMS 中所做的工作利用了适配器在不同语言之间进行语音识别的这一思想。少量适配器权重经过微调,以掌握每种目标语言独特的语音和语法特征。因此,MMS 使得一个大型基础模型(例如,mms-1b-all 检查点)和 1000 多个小型适配器层(mms-1b-all 每个约 2.5M 权重)能够理解和转录多种语言。这大大减少了为每种语言开发独立模型的计算需求。

太棒了!现在我们了解了动机和理论,接下来我们来了解如何微调 mms-1b-all 的适配器权重 🔥

Notebook 设置

正如之前在“在多语言 ASR 上微调 XLS-R”博文中所做的那样,我们将在 Common Voice 的低资源 ASR 数据集上微调模型,该数据集仅包含约 4 小时的验证训练数据。

就像 Wav2Vec2 或 XLS-R 一样,MMS 使用连接主义时间分类 (CTC) 进行微调,CTC 是一种用于训练神经网络解决序列到序列问题(例如 ASR 和手写识别)的算法。

有关 CTC 算法的更多详细信息,我强烈推荐阅读 Awni Hannun 撰写的精彩博文《使用 CTC 进行序列建模 (2017)》

在开始之前,让我们安装 datasetstransformers。此外,我们还需要 torchaudio 来加载音频文件,以及 jiwer 来使用词错误率 (WER) 指标 1 {}^1 评估我们微调过的模型。

%%capture
!pip install --upgrade pip 
!pip install datasets[audio]
!pip install evaluate
!pip install git+https://github.com/huggingface/transformers.git
!pip install jiwer
!pip install accelerate

我们强烈建议在训练期间将您的训练检查点直接上传到 🤗 Hub。Hub 仓库内置了版本控制,因此您可以确保在训练期间不会丢失任何模型检查点。

为此,您必须存储来自 Hugging Face 网站的身份验证令牌(如果您尚未注册,请在此处注册!)。

from huggingface_hub import notebook_login

notebook_login()

准备数据、分词器、特征提取器

ASR 模型将语音转写为文本,这意味着我们既需要一个将语音信号处理成模型输入格式(例如特征向量)的特征提取器,也需要一个将模型输出格式处理成文本的分词器。

在 🤗 Transformers 中,MMS 模型因此附带了一个特征提取器,称为 Wav2Vec2FeatureExtractor,以及一个分词器,称为 Wav2Vec2CTCTokenizer

让我们从创建分词器开始,用它来将预测的输出类别解码为输出转写文本。

创建 Wav2Vec2CTCTokenizer

经过微调的 MMS 模型,例如 mms-1b-all,已经有一个随模型检查点附带的分词器。然而,由于我们希望在特定低资源语言的数据上微调模型,建议完全移除该分词器和词汇输出层,并简单地根据训练数据本身创建新的。

在 CTC 上微调的 Wav2Vec2 类模型通过一次前向传播转录音频文件,首先将音频输入处理成一系列处理后的上下文表示,然后使用最终的词汇输出层将每个上下文表示分类为一个字符,该字符表示转录结果。

此层的输出大小对应于我们之前定义的词汇中的标记数量,我们将从用于微调的标记数据集中提取这些标记。因此,第一步,我们将查看 Common Voice 中选定的数据集,并根据转录定义一个词汇表。

在本笔记中,我们将使用 Common Voice 的 6.1 版数据集,用于土耳其语。土耳其语对应的语言代码为 "tr"

太棒了,现在我们可以使用 🤗 Datasets 的简单 API 下载数据。数据集名称是 "mozilla-foundation/common_voice_6_1",配置名称对应于语言代码,在我们的例子中是 "tr"

注意:在能够下载数据集之前,您必须通过登录您的 Hugging Face 账户,访问数据集仓库页面并点击“同意并访问仓库”来访问它

Common Voice 有许多不同的拆分,包括 invalidated,指的是未被评为“足够干净”而无法被认为有用的数据。在本笔记中,我们将仅使用 "train""validation""test" 拆分。

因为土耳其语数据集很小,我们将验证数据和训练数据合并成一个训练数据集,并且只使用测试数据进行验证。

from datasets import load_dataset, load_metric, Audio

common_voice_train = load_dataset("mozilla-foundation/common_voice_6_1", "tr", split="train+validation", use_auth_token=True)
common_voice_test = load_dataset("mozilla-foundation/common_voice_6_1", "tr", split="test", use_auth_token=True)

许多 ASR 数据集只为每个音频数组 ('audio') 和文件 ('path') 提供目标文本 ('sentence')。Common Voice 实际上提供了关于每个音频文件的更多信息,例如 'accent' 等。为了使笔记本尽可能通用,我们只考虑转录文本进行微调。

common_voice_train = common_voice_train.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])

让我们写一个简短的函数来显示数据集的一些随机样本,并运行几次以感受一下转写文本的特点。

from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))
show_random_elements(common_voice_train.remove_columns(["path", "audio"]), num_examples=10)
Oylar teker teker elle sayılacak.
Son olaylar endişe seviyesini yükseltti.
Tek bir kart hepsinin kapılarını açıyor.
Blogcular da tam bundan bahsetmek istiyor.
Bu Aralık iki bin onda oldu.
Fiyatın altmış altı milyon avro olduğu bildirildi.
Ardından da silahlı çatışmalar çıktı.
"Romanya'da kurumlar gelir vergisi oranı yüzde on altı."
Bu konuda neden bu kadar az şey söylendiğini açıklayabilir misiniz?

好的!转录看起来相当干净。在翻译了转录的句子后,似乎该语言更像书面文本而不是嘈杂的对话。考虑到Common Voice是一个众包阅读语音语料库,这很有道理。

我们可以看到转录中包含一些特殊字符,例如 ,.?!;:。如果没有语言模型,将语音块分类到这些特殊字符会困难得多,因为它们并不真正对应于特征性声音单元。例如,字母 "s" 有一个或多或少清晰的声音,而特殊字符 "." 则没有。此外,为了理解语音信号的含义,通常不需要在转录中包含特殊字符。

让我们简单地移除所有对词义没有贡献且无法真正用声音表示的字符,并对文本进行规范化。

import re
chars_to_remove_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\']'

def remove_special_characters(batch):
    batch["sentence"] = re.sub(chars_to_remove_regex, '', batch["sentence"]).lower()
    return batch
common_voice_train = common_voice_train.map(remove_special_characters)
common_voice_test = common_voice_test.map(remove_special_characters)

让我们再看一下处理后的文本标签。

show_random_elements(common_voice_train.remove_columns(["path","audio"]))
i̇kinci tur müzakereler eylül ayında başlayacak
jani ve babası bu düşüncelerinde yalnız değil
onurun gözlerindeki büyü
bandiç oyların yüzde kırk sekiz virgül elli dördünü topladı
bu imkansız
bu konu açık değildir
cinayet kamuoyunu şiddetle sarstı
kentin sokakları iki metre su altında kaldı
muhalefet partileri hükümete karşı ciddi bir mücadele ortaya koyabiliyorlar mı
festivale tüm dünyadan elli film katılıyor

很好!这看起来好多了。我们已经从转录中删除了大部分特殊字符,并将其规范化为全小写。

在最终确定预处理之前,咨询目标语言的母语使用者以查看文本是否可以进一步简化总是很有利的。对于这篇博客文章,Merve 慷慨地快速查看了一下,并指出“带帽”字符——例如 â——在土耳其语中已不再使用,可以替换为它们的“不带帽”等价物,例如 a

这意味着我们应该将像 "yargı sistemi hâlâ sağlıksız" 这样的句子替换为 "yargı sistemi hala sağlıksız"

让我们再写一个简短的映射函数来进一步简化文本标签。

def replace_hatted_characters(batch):
    batch["sentence"] = re.sub('[â]', 'a', batch["sentence"])
    batch["sentence"] = re.sub('[î]', 'i', batch["sentence"])
    batch["sentence"] = re.sub('[ô]', 'o', batch["sentence"])
    batch["sentence"] = re.sub('[û]', 'u', batch["sentence"])
    return batch
common_voice_train = common_voice_train.map(replace_hatted_characters)
common_voice_test = common_voice_test.map(replace_hatted_characters)

在 CTC 中,通常将语音块分类为字母,所以我们在这里也这样做。让我们提取训练和测试数据中所有不同的字母,并从这个字母集合构建我们的词汇表。

我们编写了一个映射函数,将所有转录连接成一个长转录,然后将字符串转换为一组字符。重要的是将参数 batched=True 传递给 map(...) 函数,以便映射函数可以一次访问所有转录。

def extract_all_chars(batch):
  all_text = " ".join(batch["sentence"])
  vocab = list(set(all_text))
  return {"vocab": [vocab], "all_text": [all_text]}
vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)

现在,我们创建训练集和测试集中所有不同字母的并集,并将结果列表转换为一个带索引的字典。

vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))
vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
vocab_dict
    {' ': 0,
     'a': 1,
     'b': 2,
     'c': 3,
     'd': 4,
     'e': 5,
     'f': 6,
     'g': 7,
     'h': 8,
     'i': 9,
     'j': 10,
     'k': 11,
     'l': 12,
     'm': 13,
     'n': 14,
     'o': 15,
     'p': 16,
     'q': 17,
     'r': 18,
     's': 19,
     't': 20,
     'u': 21,
     'v': 22,
     'w': 23,
     'x': 24,
     'y': 25,
     'z': 26,
     'ç': 27,
     'ë': 28,
     'ö': 29,
     'ü': 30,
     'ğ': 31,
     'ı': 32,
     'ş': 33,
     '̇': 34}

酷,我们看到字母表中的所有字母都出现在数据集中(这并不奇怪),并且我们也提取了特殊字符 ""'。请注意,我们没有排除这些特殊字符,因为模型必须学会预测何时一个单词结束,否则预测将始终是一串字母,这将使单词之间无法分离。

应该始终牢记,预处理是训练模型之前非常重要的一步。例如,我们不希望模型仅仅因为我们忘记规范化数据而区分 aAaA 之间的区别根本不取决于字母的“发音”,而更多地取决于语法规则——例如,在句首使用大写字母。因此,消除大写字母和非大写字母之间的区别是明智的,这样模型更容易学习转录语音。

为了更清楚地表明 " " 有其自己的标记类别,我们给它一个更明显的字符 |。此外,我们还添加了一个“未知”标记,以便模型以后可以处理 Common Voice 训练集中未遇到的字符。

vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

最后,我们还添加了一个与 CTC 的“空白标记”相对应的填充标记。“空白标记”是 CTC 算法的核心组成部分。有关更多信息,请参阅此处的“对齐”部分。

vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)
    37

酷,现在我们的词汇表已完成,包含 37 个标记,这意味着我们将作为适配器权重的一部分添加到预训练 MMS 检查点顶部的线性层将具有 37 的输出维度。

由于单个 MMS 检查点可以为多种语言提供定制的权重,因此分词器也可以由多个词汇表组成。因此,我们需要嵌套我们的 vocab_dict,以便将来可以向词汇表中添加更多语言。该字典应与用于适配器权重的名称进行嵌套,并且该名称保存在分词器配置中,名称为 target_lang

让我们使用 ISO-639-3 语言代码,就像原始的 mms-1b-all 检查点一样。

target_lang = "tur"

让我们定义一个空字典,我们可以将刚刚创建的词汇表添加到其中

new_vocab_dict = {target_lang: vocab_dict}

注意:如果您想使用此笔记本将新的适配器层添加到现有模型仓库,请务必不要创建空的、新的词汇字典,而是重新使用已存在的词汇字典。为此,您应该取消注释以下单元格,并将 "patrickvonplaten/wav2vec2-large-mms-1b-turkish-colab" 替换为您要添加适配器权重的模型仓库 ID。

# from transformers import Wav2Vec2CTCTokenizer

# mms_adapter_repo = "patrickvonplaten/wav2vec2-large-mms-1b-turkish-colab"  # make sure to replace this path with a repo to which you want to add your new adapter weights

# tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(mms_adapter_repo)
# new_vocab = tokenizer.vocab

# new_vocab[target_lang] = vocab_dict

现在让我们将词汇表保存为 json 文件。

import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(new_vocab_dict, vocab_file)

最后一步,我们使用 json 文件将词汇表加载到 Wav2Vec2CTCTokenizer 类的实例中。

from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|", target_lang=target_lang)

如果想将刚刚创建的分词器与本笔记本中微调过的模型一起重复使用,强烈建议将 tokenizer 上传到 🤗 Hub。我们把要上传文件的仓库命名为 "wav2vec2-large-mms-1b-turkish-colab"

repo_name = "wav2vec2-large-mms-1b-turkish-colab"

然后将分词器上传到 🤗 Hub

tokenizer.push_to_hub(repo_name)
    CommitInfo(commit_url='https://huggingface.co/patrickvonplaten/wav2vec2-large-mms-1b-turkish-colab/commit/48cccbfd6059aa6ce655e9d94b8358ba39536cb7', commit_message='Upload tokenizer', commit_description='', oid='48cccbfd6059aa6ce655e9d94b8358ba39536cb7', pr_url=None, pr_revision=None, pr_num=None)

太棒了,您可以在 https://huggingface.co/<您的用户名>/wav2vec2-large-mms-1b-tr-colab 下看到刚刚创建的存储库

创建 Wav2Vec2FeatureExtractor

语音是一种连续信号,要由计算机处理,它首先必须离散化,这通常称为采样。采样率在此处扮演重要角色,因为它定义了每秒测量多少语音信号数据点。因此,以更高的采样率采样会更好地逼近真实语音信号,但每秒也需要更多值。

预训练的检查点期望其输入数据以与训练时使用的数据大致相同的分布进行采样。以两种不同速率采样的相同语音信号具有非常不同的分布,例如,采样率加倍会导致数据点数量加倍。因此,在微调 ASR 模型的预训练检查点之前,验证用于预训练模型的数据采样率是否与用于微调模型的数据集的采样率匹配至关重要。

一个 Wav2Vec2FeatureExtractor 对象需要以下参数才能实例化

  • feature_size: 语音模型将特征向量序列作为输入。虽然这个序列的长度显然是变化的,但特征大小不应该变化。在 Wav2Vec2 的情况下,特征大小为 1,因为模型是在原始语音信号上训练的 2 {}^2
  • sampling_rate: 模型训练时使用的采样率。
  • padding_value:对于批量推理,较短的输入需要用特定值填充。
  • do_normalize:输入是否应该进行零均值单位方差归一化。通常,语音模型在归一化输入后表现更好。
  • return_attention_mask: 模型是否应使用 attention_mask 进行批量推理。一般来说,XLS-R 模型检查点应该始终使用 attention_mask
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)

太棒了,MMS 的特征提取管道由此完全定义!

为了提高用户友好性,特征提取器和分词器被封装在一个 Wav2Vec2Processor 类中,这样只需要一个 modelprocessor 对象。

from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

接下来,我们可以准备数据集了。

预处理数据

到目前为止,我们只查看了语音信号的转录,而没有查看实际值。除了 sentence 之外,我们的数据集中还包含另外两个列名 pathaudiopath 表示音频文件的绝对路径,audio 表示已加载的音频数据。MMS 期望输入格式为 16 kHz 的一维数组。这意味着必须加载并重新采样音频文件。

值得庆幸的是,当列名为 audio 时,datasets 会自动完成此操作。我们来试一下。

common_voice_train[0]["audio"]
    {'path': '/root/.cache/huggingface/datasets/downloads/extracted/71ba9bd154da9d8c769b736301417178729d2b87b9e00cda59f6450f742ed778/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_17346025.mp3',
     'array': array([ 0.00000000e+00, -2.98378618e-13, -1.59835903e-13, ...,
            -2.01663317e-12, -1.87991593e-12, -1.17969588e-12]),
     'sampling_rate': 48000}

在上面的示例中,我们可以看到音频数据以 48kHz 的采样率加载,而模型期望的采样率为 16kHz。我们可以通过使用 cast_column 将音频特征设置为正确的采样率。

common_voice_train = common_voice_train.cast_column("audio", Audio(sampling_rate=16_000))
common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))

我们再来看看 "audio"

common_voice_train[0]["audio"]
{'path': '/root/.cache/huggingface/datasets/downloads/extracted/71ba9bd154da9d8c769b736301417178729d2b87b9e00cda59f6450f742ed778/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_17346025.mp3',
 'array': array([ 9.09494702e-13, -6.13908924e-12, -1.09139364e-11, ...,
         1.81898940e-12,  4.54747351e-13,  3.63797881e-12]),
 'sampling_rate': 16000}

这似乎奏效了!让我们最后检查一下数据是否正确准备,打印出语音输入的形状、其转录以及相应的采样率。

rand_int = random.randint(0, len(common_voice_train)-1)

print("Target text:", common_voice_train[rand_int]["sentence"])
print("Input array shape:", common_voice_train[rand_int]["audio"]["array"].shape)
print("Sampling rate:", common_voice_train[rand_int]["audio"]["sampling_rate"])
    Target text: bağış anlaşması bir ağustosta imzalandı
    Input array shape: (70656,)
    Sampling rate: 16000

好的!一切看起来都没问题——数据是一维数组,采样率总是 16kHz,目标文本也已规范化。

最后,我们可以利用 Wav2Vec2Processor 将数据处理成 Wav2Vec2ForCTC 训练所需的格式。为此,我们使用 Dataset 的 map(...) 函数。

首先,我们通过调用 batch["audio"] 来加载并重新采样音频数据。其次,我们从加载的音频文件中提取 input_values。在我们的例子中,Wav2Vec2Processor 仅对数据进行归一化。然而,对于其他语音模型,此步骤可能包含更复杂的特征提取,例如对数梅尔特征提取。第三,我们将转录编码为标签 ID。

注意:这个映射函数是 Wav2Vec2Processor 类应如何使用的很好的例子。在“正常”情况下,调用 processor(...) 会被重定向到 Wav2Vec2FeatureExtractor 的调用方法。然而,当将处理器封装到 as_target_processor 上下文中时,相同的方法会被重定向到 Wav2Vec2CTCTokenizer 的调用方法。有关更多信息,请查阅文档

def prepare_dataset(batch):
    audio = batch["audio"]

    # batched output is "un-batched"
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    batch["input_length"] = len(batch["input_values"])

    batch["labels"] = processor(text=batch["sentence"]).input_ids
    return batch

让我们将数据准备函数应用到所有样本上。

common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names)
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names)

注意datasets 会自动处理音频加载和重采样。如果您希望实现自己的自定义数据加载/采样,请随意只使用 "path" 列,并忽略 "audio" 列。

太棒了,现在我们准备好开始训练了!

训练

数据已处理完毕,我们已准备好开始设置训练管线。我们将使用 🤗 的 Trainer,为此我们主要需要执行以下操作

  • 定义一个数据收集器。与大多数 NLP 模型不同,MMS 的输入长度远大于输出长度。例如,输入长度为 50000 的样本的输出长度不超过 100。鉴于输入尺寸较大,动态填充训练批次效率更高,这意味着所有训练样本都应仅填充到其批次中最长的样本,而不是整体最长的样本。因此,微调 MMS 需要一个特殊的填充数据收集器,我们将在下面定义。

  • 评估指标。在训练期间,模型应以词错误率进行评估。我们应该相应地定义一个 compute_metrics 函数。

  • 加载预训练检查点。我们需要加载预训练检查点并对其进行正确配置以进行训练。

  • 定义训练配置。

在微调模型后,我们将在测试数据上对其进行正确评估,并验证它确实学会了正确转写语音。

设置训练器

我们从定义数据收集器开始。数据收集器的代码复制自这个示例

不赘述过多细节,与常见数据收集器不同,此数据收集器对 input_valueslabels 进行不同处理,因此对它们应用了两个独立的填充函数(再次利用 MMS 处理器的上下文管理器)。这是必要的,因为在语音识别中,输入和输出是不同的模态,因此它们不应由相同的填充函数处理。与常见数据收集器类似,标签中的填充标记用 -100 填充,以便在计算损失时考虑这些标记。

import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for proccessing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )

        labels_batch = self.processor.pad(
            labels=label_features,
            padding=self.padding,
            return_tensors="pt",
        )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

接下来,定义评估指标。如前所述,ASR 中最主要的指标是词错误率 (WER),因此我们在这个 notebook 中也使用它。

from evaluate import load

wer_metric = load("wer")

模型将返回一系列 logits 向量:y1,,ym \mathbf{y}_1, \ldots, \mathbf{y}_m ,其中 y1=fθ(x1,,xn)[0] \mathbf{y}_1 = f_{\theta}(x_1, \ldots, x_n)[0] n>>m n >> m

一个 logit 向量 y1 \mathbf{y}_1 包含我们之前定义的词汇中每个单词的对数几率,因此 len(yi)= \text{len}(\mathbf{y}_i) = config.vocab_size。我们对模型最可能的预测感兴趣,因此取 logits 的 argmax(...)。此外,我们通过将 -100 替换为 pad_token_id 并解码 ID 将编码标签转换回原始字符串,同时确保连续标记以 CTC 样式分组 1 {}^1

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

现在,我们可以加载 mms-1b-all 的预训练检查点。分词器的 pad_token_id 必须用于定义模型的 pad_token_id,或者在 Wav2Vec2ForCTC 的情况下,也用于 CTC 的空白标记 2 {}^2

由于我们只训练一小部分权重,模型不易过拟合。因此,我们确保禁用所有 dropout 层。

注意:当使用此笔记本在 Common Voice 的另一种语言上训练 MMS 时,这些超参数设置可能效果不佳。请根据您的用例随意调整。

from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/mms-1b-all",
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    layerdrop=0.0,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
    ignore_mismatched_sizes=True,
)
    Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/mms-1b-all and are newly initialized because the shapes did not match:
    - lm_head.bias: found shape torch.Size([154]) in the checkpoint and torch.Size([39]) in the model instantiated
    - lm_head.weight: found shape torch.Size([154, 1280]) in the checkpoint and torch.Size([39, 1280]) in the model instantiated
    You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

注意:一些权重预计会重新初始化。这些权重对应于新初始化的词汇输出层。

现在我们要确保只训练适配器权重,而模型的其余部分保持冻结。

首先,我们重新初始化所有适配器权重,这可以通过方便的 init_adapter_layers 方法完成。也可以不重新初始化适配器权重并继续微调,但在这种情况下,在训练之前应确保通过 load_adapter(...) 方法加载合适的适配器权重。然而,通常词汇表与自定义训练数据仍然不是很匹配,因此通常更容易直接重新初始化所有适配器层,以便它们可以轻松微调。

model.init_adapter_layers()

接下来,我们冻结所有权重,除了适配器层。

model.freeze_base_model()

adapter_weights = model._get_adapters()
for param in adapter_weights.values():
    param.requires_grad = True

最后一步,我们定义所有与训练相关的参数。对其中一些参数进行更多解释:

  • group_by_length 通过将输入长度相似的训练样本分组到一个批次中,使训练更高效。这可以显著缩短训练时间,通过大大减少模型中传递的无用填充标记的总数。
  • learning_rate 选择为 1e-3,这是使用 Adam 训练时的常见默认值。其他学习率可能同样有效。

有关其他参数的更多解释,可以查看文档。为了节省 GPU 内存,我们启用 PyTorch 的梯度检查点,并将损失减少设置为“均值”。MMS 适配器微调以极快的速度收敛到非常好的性能,因此即使对于像 4 小时这样小的数据集,我们也只训练 4 个 epoch。训练期间,每 200 个训练步骤将异步上传一个检查点到 Hub。它还允许您在模型仍在训练时使用演示小部件。

注意:如果不想将模型检查点上传到 Hub,只需设置 push_to_hub=False

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir=repo_name,
  group_by_length=True,
  per_device_train_batch_size=32,
  evaluation_strategy="steps",
  num_train_epochs=4,
  gradient_checkpointing=True,
  fp16=True,
  save_steps=200,
  eval_steps=100,
  logging_steps=100,
  learning_rate=1e-3,
  warmup_steps=100,
  save_total_limit=2,
  push_to_hub=True,
)

现在,所有实例都可以传递给 Trainer,我们准备开始训练了!

from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor,
)

1 {}^1 为了使模型独立于说话人语速,在 CTC 中,连续的相同标记被简单地分组为一个标记。然而,在解码时,编码的标签不应分组,因为它们不对应于模型的预测标记,这就是为什么必须传递 group_tokens=False 参数。如果我们不传递此参数,像 "hello" 这样的单词将被错误地编码并解码为 "helo"2 {}^2 空白标记允许模型预测一个单词,例如 "hello",通过强制它在两个“l”之间插入空白标记。我们模型对 "hello" 的符合 CTC 的预测将是 [PAD] [PAD] "h" "e" "e" "l" "l" [PAD] "l" "o" "o" [PAD]

训练

训练时间应少于 30 分钟,具体取决于所使用的 GPU。

trainer.train()
训练损失 训练步骤 验证损失 词错误率 (Wer)
4.905 100 0.215 0.280
0.290 200 0.167 0.232
0.2659 300 0.161 0.229
0.2398 400 0.156 0.223

训练损失和验证 WER 都在稳步下降。

我们发现,仅对 mms-1b-all 的适配器层进行 100 步的微调,就已大幅超越了此处所示的整个 xls-r-300m 检查点微调的性能。

官方论文和这个快速比较可以清楚看出,mms-1b-all 具有更高的知识迁移到低资源语言的能力,应优于 xls-r-300m。此外,由于只训练了一小部分层,训练效率也更高,内存占用更少。

适配器权重将作为模型检查点的一部分上传,但我们也要确保单独保存它们,以便可以轻松地进行卸载和加载。

让我们将所有适配器层保存到训练输出目录中,以便将其正确上传到 Hub。

from safetensors.torch import save_file as safe_save_file
from transformers.models.wav2vec2.modeling_wav2vec2 import WAV2VEC2_ADAPTER_SAFE_FILE
import os

adapter_file = WAV2VEC2_ADAPTER_SAFE_FILE.format(target_lang)
adapter_file = os.path.join(training_args.output_dir, adapter_file)

safe_save_file(model._get_adapters(), adapter_file, metadata={"format": "pt"})

最后,您可以将训练结果上传到 🤗 Hub。

trainer.push_to_hub()

适配器权重训练的主要优点之一是,“基础”模型(约占模型权重的 99%)保持不变,只需共享一个小型2.5M 适配器检查点即可使用训练好的检查点。

这使得训练额外的适配器层并将其添加到您的仓库变得异常简单。

您可以通过简单地重新运行此脚本,并将其更改为您想要训练的不同语言(例如瑞典语的 swe)来轻松实现这一点。此外,您应该确保词汇表不会被完全覆盖,而是将新语言词汇表附加到现有词汇表中,如上述注释掉的单元格中所述。

为了演示如何加载不同的适配器层,我还训练并上传了一个瑞典语的适配器层,其 ISO 语言代码为 swe,您可以在这里看到。

您可以通过 from_pretrained(...) 像往常一样加载微调后的检查点,但您应确保向该方法添加 target_lang="<您的语言代码>",以便加载正确的适配器。您还应为您的分词器正确设置目标语言。

我们先来看看如何加载土耳其语检查点。

model_id = "patrickvonplaten/wav2vec2-large-mms-1b-turkish-colab"

model = Wav2Vec2ForCTC.from_pretrained(model_id, target_lang="tur").to("cuda")
processor = Wav2Vec2Processor.from_pretrained(model_id)

processor.tokenizer.set_target_lang("tur")

让我们检查模型是否能正确转录土耳其语

from datasets import Audio

common_voice_test_tr = load_dataset("mozilla-foundation/common_voice_6_1", "tr", data_dir="./cv-corpus-6.1-2020-12-11", split="test", use_auth_token=True)
common_voice_test_tr = common_voice_test_tr.cast_column("audio", Audio(sampling_rate=16_000))

我们处理音频,进行前向传播并预测 ID

input_dict = processor(common_voice_test_tr[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt", padding=True)

logits = model(input_dict.input_values.to("cuda")).logits

pred_ids = torch.argmax(logits, dim=-1)[0]

最后,我们可以解码示例。

print("Prediction:")
print(processor.decode(pred_ids))

print("\nReference:")
print(common_voice_test_tr[0]["sentence"].lower())

输出:

    Prediction:
    pekçoğuda roman toplumundan geliyor

    Reference:
    pek çoğu da roman toplumundan geliyor.

这看起来几乎完全正确,只是第一个单词应该多加两个空字符。现在,通过调用model.load_adapter(...) 并将分词器也更改为瑞典语,可以非常简单地将适配器更改为瑞典语。

model.load_adapter("swe")
processor.tokenizer.set_target_lang("swe")

我们再次从 Common Voice 加载瑞典语测试集

common_voice_test_swe = load_dataset("mozilla-foundation/common_voice_6_1", "sv-SE", data_dir="./cv-corpus-6.1-2020-12-11", split="test", use_auth_token=True)
common_voice_test_swe = common_voice_test_swe.cast_column("audio", Audio(sampling_rate=16_000))

并转录一个样本

input_dict = processor(common_voice_test_swe[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt", padding=True)

logits = model(input_dict.input_values.to("cuda")).logits

pred_ids = torch.argmax(logits, dim=-1)[0]

print("Prediction:")
print(processor.decode(pred_ids))

print("\nReference:")
print(common_voice_test_swe[0]["sentence"].lower())

输出:

    Prediction:
    jag lämnade grovjobbet åt honom

    Reference:
    jag lämnade grovjobbet åt honom.

太棒了,这看起来是完美的转录!

我们在这篇博文中展示了 MMS 适配器权重微调不仅在低资源语言上提供了最先进的性能,而且还显著加快了训练时间,并允许轻松构建定制适配器权重集合。

相关文章和附加链接列于此

社区

注册登录 发表评论