使用 🤗 Transformers 微调 XLS-R 进行多语言 ASR
新 (11/2021):此博客文章已更新,以介绍 XLSR 的继任者,名为 XLS-R。
Wav2Vec2 是用于自动语音识别 (ASR) 的预训练模型,由 Alexei Baevski、Michael Auli 和 Alex Conneau 于 2020 年 9 月 发布。在 Wav2Vec2 在最流行的英语 ASR 数据集 LibriSpeech 上展示出卓越性能后不久,Facebook AI 发布了 Wav2Vec2 的多语言版本,称为 XLSR。XLSR 代表跨语言语音表示,指的是模型学习在多种语言中有用的语音表示的能力。
XLSR 的继任者,简称为 XLS-R(指的是“XLM-R 用于语音”),由 Arun Babu、Changhan Wang、Andros Tjandra 等人于 2021 年 11 月 发布。XLS-R 使用了近 50 万小时的 128 种语言的音频数据进行自监督预训练,其大小从 3 亿参数到 20 亿参数不等。您可以在 🤗 Hub 上找到预训练检查点。
与 BERT 的掩码语言建模目标类似,XLS-R 通过在自监督预训练期间(即下方左图)将特征向量随机掩码,然后将其传递给 Transformer 网络来学习上下文语音表示。
对于微调,在预训练网络之上添加一个线性层,以便在音频下游任务(如语音识别、语音翻译和音频分类)的带标签数据上训练模型(即下方右图)。
XLS-R 在语音识别、语音翻译和说话人/语言识别方面均表现出比以前最先进结果显著的改进,参见官方论文中的表 3-6、表 7-10 和表 11-12。
设置
在本博客中,我们将详细解释如何微调 XLS-R——更具体地说是预训练检查点 Wav2Vec2-XLS-R-300M——用于 ASR。
为了演示目的,我们将模型在 Common Voice 的低资源 ASR 数据集上进行微调,该数据集仅包含约 4 小时已验证的训练数据。
XLS-R 使用连接主义时间分类 (CTC) 进行微调,这是一种用于训练序列到序列问题(如 ASR 和手写识别)神经网络的算法。
我强烈推荐阅读 Awni Hannun 撰写的精彩博客文章 Sequence Modeling with CTC (2017)。
在开始之前,我们先安装 datasets
和 transformers
。此外,我们还需要 torchaudio
来加载音频文件,以及 jiwer
来使用 词错误率 (WER) 指标 评估我们微调的模型。
!pip install datasets==1.18.3
!pip install transformers==4.11.3
!pip install huggingface_hub==0.1
!pip install torchaudio
!pip install librosa
!pip install jiwer
我们强烈建议您在训练期间将训练检查点直接上传到 Hugging Face Hub。Hugging Face Hub 集成了版本控制,因此您可以确保在训练期间不会丢失任何模型检查点。
为此,您必须存储来自 Hugging Face 网站的身份验证令牌(如果您尚未注册,请在此处注册!)。
from huggingface_hub import notebook_login
notebook_login()
打印输出
Login successful
Your token has been saved to /root/.huggingface/token
然后您需要安装 Git-LFS 才能上传您的模型检查点
apt install git-lfs
在 论文中,模型使用音素错误率 (PER) 进行评估,但目前 ASR 中最常见的指标是词错误率 (WER)。为了使本笔记本尽可能通用,我们决定使用 WER 评估模型。
准备数据、分词器、特征提取器
ASR 模型将语音转写为文本,这意味着我们既需要一个将语音信号处理成模型输入格式(例如特征向量)的特征提取器,也需要一个将模型输出格式处理成文本的分词器。
在 🤗 Transformers 中,XLS-R 模型因此配备了分词器 Wav2Vec2CTCTokenizer 和特征提取器 Wav2Vec2FeatureExtractor。
让我们从创建分词器开始,用它来将预测的输出类别解码为输出转写文本。
创建 Wav2Vec2CTCTokenizer
预训练的 XLS-R 模型将语音信号映射到上下文表示序列,如上图所示。然而,对于语音识别,模型必须将此上下文表示序列映射到其对应的转录,这意味着必须在 Transformer 块之上添加一个线性层(上图黄色部分所示)。此线性层用于将每个上下文表示分类为一个标记类别,类似于在 BERT 的嵌入之上添加线性层以在预训练后进行进一步分类的方式(参见以下博客文章的“BERT”部分)。在预训练之后,在 BERT 的嵌入之上添加一个线性层以进行进一步分类——参见此博客文章的“BERT”部分。
此层的输出大小对应于词汇表中的标记数量,这**不**取决于 XLS-R 的预训练任务,而仅取决于用于微调的带标签数据集。因此,第一步,我们将查看选择的 Common Voice 数据集,并根据转录定义一个词汇表。
首先,我们前往 Common Voice 官方网站并选择一种语言来微调 XLS-R。在本笔记本中,我们将使用土耳其语。
对于每个特定语言的数据集,您可以找到与您所选语言对应的语言代码。在 Common Voice 上,查找“版本”字段。语言代码对应于下划线之前的前缀。例如,土耳其语的语言代码是 "tr"
。
很好!现在我们可以使用 🤗 Datasets 简单的 API 来下载数据。数据集名称是 "common_voice"
,配置名称对应于语言代码,在本例中是 "tr"
。
Common Voice 有许多不同的拆分,包括 invalidated
,它指的是未被评为“足够清晰”而无法被认为有用的数据。在本笔记本中,我们只使用 "train"
、"validation"
和 "test"
拆分。
由于土耳其语数据集很小,我们将验证数据和训练数据合并到一个训练数据集中,只使用测试数据进行验证。
from datasets import load_dataset, load_metric, Audio
common_voice_train = load_dataset("common_voice", "tr", split="train+validation")
common_voice_test = load_dataset("common_voice", "tr", split="test")
许多 ASR 数据集仅提供每个音频数组 'audio'
和文件 'path'
的目标文本 'sentence'
。Common Voice 实际上提供了每个音频文件的更多信息,例如 'accent'
等。为了使笔记本尽可能通用,我们只考虑转录文本进行微调。
common_voice_train = common_voice_train.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
让我们写一个简短的函数来显示数据集的一些随机样本,并运行几次以感受一下转写文本的特点。
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML
def show_random_elements(dataset, num_examples=10):
assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
picks = []
for _ in range(num_examples):
pick = random.randint(0, len(dataset)-1)
while pick in picks:
pick = random.randint(0, len(dataset)-1)
picks.append(pick)
df = pd.DataFrame(dataset[picks])
display(HTML(df.to_html()))
打印输出
索引 | 句子 |
---|---|
1 | Jonuz是唯一接受短期任务的候选人。 |
2 | 我们从这场斗争中获得希望。 |
3 | 展览中展示了五项克罗地亚创新。 |
4 | 万物皆有其名。 |
5 | 该机构已准备好私有化。 |
6 | 定居点的景色很美。 |
7 | 事件的肇事者未能找到。 |
8 | 然而,这些努力都白费了。 |
9 | 该项目价值2.77百万欧元。 |
10 | 大型重建项目分为四个阶段。 |
好的!转录看起来相当清晰。翻译这些转录的句子后,似乎这种语言更像书面文本而不是嘈杂的对话。这很合理,考虑到 Common Voice 是一个众包的朗读语音语料库。
我们可以看到转录包含一些特殊字符,例如 ,.?!;:
。在没有语言模型的情况下,将语音片段分类为这些特殊字符要困难得多,因为它们与特征声音单元没有真正的对应关系。例如,字母 "s"
有一个或多或少清晰的声音,而特殊字符 "."
则没有。此外,为了理解语音信号的含义,通常不需要在转录中包含特殊字符。
让我们简单地移除所有对词义没有贡献且无法真正用声音表示的字符,并对文本进行规范化。
import re
chars_to_remove_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\']'
def remove_special_characters(batch):
batch["sentence"] = re.sub(chars_to_remove_regex, '', batch["sentence"]).lower()
return batch
common_voice_train = common_voice_train.map(remove_special_characters)
common_voice_test = common_voice_test.map(remove_special_characters)
让我们再看一下处理后的文本标签。
show_random_elements(common_voice_train.remove_columns(["path","audio"]))
打印输出
索引 | 转录 |
---|---|
1 | 他们说其中一个是为白人而战 |
2 | 马克图夫的刑期于六月结束 |
3 | 与原作不同,衣服没有脱下 |
4 | 这些物品的总价值达到一亿欧元。 |
5 | 桌上至少有两个选项。 |
6 | 这绝非不合理的狂热。 |
7 | 这种状况在1990年代随着国家分裂而改变。 |
8 | 期限是六个月。 |
9 | 但是,成本可能会高得多。 |
10 | 首府费拉坐落在一座小山上。 |
很好!这看起来好多了。我们已经从转录中删除了大部分特殊字符,并将其规范化为全小写。
在最终确定预处理之前,咨询目标语言的母语人士总是有利的,以查看文本是否可以进一步简化。对于这篇博客文章,Merve 很好心地快速浏览了一下,并指出土耳其语中像 â
这样的“带帽”字符已经不再使用了,可以用它们的“不带帽”对应物(例如 a
)替换。
这意味着我们应该将像 "yargı sistemi hâlâ sağlıksız"
这样的句子替换为 "yargı sistemi hala sağlıksız"
。
我们再编写一个简短的映射函数来进一步简化文本标签。
def replace_hatted_characters(batch):
batch["sentence"] = re.sub('[â]', 'a', batch["sentence"])
batch["sentence"] = re.sub('[î]', 'i', batch["sentence"])
batch["sentence"] = re.sub('[ô]', 'o', batch["sentence"])
batch["sentence"] = re.sub('[û]', 'u', batch["sentence"])
return batch
common_voice_train = common_voice_train.map(replace_hatted_characters)
common_voice_test = common_voice_test.map(replace_hatted_characters)
在 CTC 中,通常将语音块分类为字母,所以我们在这里也这样做。让我们提取训练和测试数据中所有不同的字母,并从这个字母集合构建我们的词汇表。
我们编写了一个映射函数,将所有转录连接成一个长转录,然后将字符串转换为一组字符。重要的是将参数 batched=True
传递给 map(...)
函数,以便映射函数可以同时访问所有转录。
def extract_all_chars(batch):
all_text = " ".join(batch["sentence"])
vocab = list(set(all_text))
return {"vocab": [vocab], "all_text": [all_text]}
vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)
现在,我们创建训练集和测试集中所有不同字母的并集,并将结果列表转换为一个带索引的字典。
vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))
vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
vocab_dict
打印输出
{
' ': 0,
'a': 1,
'b': 2,
'c': 3,
'd': 4,
'e': 5,
'f': 6,
'g': 7,
'h': 8,
'i': 9,
'j': 10,
'k': 11,
'l': 12,
'm': 13,
'n': 14,
'o': 15,
'p': 16,
'q': 17,
'r': 18,
's': 19,
't': 20,
'u': 21,
'v': 22,
'w': 23,
'x': 24,
'y': 25,
'z': 26,
'ç': 27,
'ë': 28,
'ö': 29,
'ü': 30,
'ğ': 31,
'ı': 32,
'ş': 33,
'̇': 34
}
太棒了,我们看到所有字母都出现在数据集中(这并不奇怪),而且我们还提取了特殊字符 ""
和 '
。请注意,我们没有排除这些特殊字符,因为:
模型必须学会预测何时一个词结束,否则模型预测将总是一个字符序列,这将导致无法将词彼此分离。
人们应该始终牢记,预处理是模型训练前非常重要的一步。例如,我们不希望模型区分 a
和 A
仅仅因为我们忘记了规范化数据。a
和 A
之间的区别根本不取决于字母的“发音”,而更多地取决于语法规则——例如,在句子开头使用大写字母。因此,消除大写字母和非大写字母之间的差异是明智的,这样模型就能更容易地学习转录语音。
为了更清楚地表明 " "
拥有自己的标记类别,我们给它一个更明显的字符 |
。此外,我们还添加了一个“未知”标记,以便模型以后可以处理 Common Voice 训练集中未遇到的字符。
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]
最后,我们还添加了一个填充标记,对应于 CTC 的“空白标记”。“空白标记”是 CTC 算法的核心组成部分。有关更多信息,请参阅此处的“对齐”部分。
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)
太棒了,现在我们的词汇表已完成,包含 39 个标记,这意味着我们将添加到预训练 XLS-R 检查点之上的线性层将具有 39 的输出维度。
现在让我们将词汇表保存为 json 文件。
import json
with open('vocab.json', 'w') as vocab_file:
json.dump(vocab_dict, vocab_file)
最后一步,我们使用 json 文件将词汇表加载到 Wav2Vec2CTCTokenizer
类的一个实例中。
from transformers import Wav2Vec2CTCTokenizer
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
如果想在本笔记本中将刚刚创建的分词器与微调模型重复使用,强烈建议将 tokenizer
上传到 Hugging Face Hub。我们将上传文件的仓库命名为 "wav2vec2-large-xlsr-turkish-demo-colab"
。
repo_name = "wav2vec2-large-xls-r-300m-tr-colab"
然后将分词器上传到 🤗 Hub。
tokenizer.push_to_hub(repo_name)
太好了,您可以在 https://huggingface.co/<your-username>/wav2vec2-large-xls-r-300m-tr-colab
下查看刚刚创建的仓库
创建 Wav2Vec2FeatureExtractor
语音是一种连续信号,为了能被计算机处理,它首先必须被离散化,这通常被称为**采样**。采样率在这里起着重要作用,因为它定义了每秒测量多少语音信号数据点。因此,更高的采样率会导致对*真实*语音信号更好的近似,但每秒也需要更多值。
预训练的检查点期望其输入数据以与训练时所用数据大致相同的分布进行采样。以两种不同速率采样的相同语音信号具有非常不同的分布。例如,采样率加倍会导致数据点时长加倍。因此,在微调 ASR 模型的预训练检查点之前,验证用于预训练模型的数据的采样率是否与用于微调模型的数据集的采样率匹配至关重要。
XLS-R 在 16kHz 采样率的 Babel、多语言 LibriSpeech (MLS)、Common Voice、VoxPopuli 和 VoxLingua107 音频数据上进行了预训练。Common Voice 的原始采样率为 48kHz,因此我们接下来需要将微调数据下采样到 16kHz。
Wav2Vec2FeatureExtractor
对象需要实例化以下参数:
feature_size
: 语音模型将特征向量序列作为输入。虽然此序列的长度显然不同,但特征大小不应改变。在 Wav2Vec2 的情况下,特征大小为 1,因为模型是在原始语音信号上训练的 。sampling_rate
: 模型训练时使用的采样率。padding_value
:对于批量推理,较短的输入需要用特定值填充。do_normalize
:输入是否应该进行零均值单位方差归一化。通常,语音模型在归一化输入后表现更好。return_attention_mask
: 模型是否应该使用attention_mask
进行批量推理。通常,XLS-R 模型检查点**总是**应该使用attention_mask
。
from transformers import Wav2Vec2FeatureExtractor
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)
太好了,XLS-R 的特征提取管道由此完全定义!
为了提高用户友好性,特征提取器和分词器被封装在一个 Wav2Vec2Processor
类中,这样只需要一个 model
和 processor
对象。
from transformers import Wav2Vec2Processor
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
接下来,我们可以准备数据集了。
预处理数据
到目前为止,我们还没有查看语音信号的实际值,而只是转录。除了 sentence
,我们的数据集中还包含另外两个列名 path
和 audio
。path
表示音频文件的绝对路径。让我们看看。
common_voice_train[0]["path"]
XLS-R 期望以 16 kHz 的一维数组格式输入。这意味着必须加载并重采样音频文件。
幸运的是,datasets
通过调用另一个列 audio
自动完成此操作。让我们试一试。
common_voice_train[0]["audio"]
{'array': array([ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
-8.8930130e-05, -3.8027763e-05, -2.9146671e-05], dtype=float32),
'path': '/root/.cache/huggingface/datasets/downloads/extracted/05be0c29807a73c9b099873d2f5975dae6d05e9f7d577458a2466ecb9a2b0c6b/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_21921195.mp3',
'sampling_rate': 48000}
太好了,我们可以看到音频文件已自动加载。这要归功于 datasets == 1.18.3
中引入的新的 "Audio"
特性,它可以在调用时动态加载和重采样音频文件。
在上面的示例中,我们可以看到音频数据以 48kHz 的采样率加载,而模型期望的采样率为 16kHz。我们可以通过使用 cast_column
将音频特征设置为正确的采样率
common_voice_train = common_voice_train.cast_column("audio", Audio(sampling_rate=16_000))
common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))
我们再看看 "audio"
。
common_voice_train[0]["audio"]
{'array': array([ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
-7.4556941e-05, -1.4621433e-05, -5.7861507e-05], dtype=float32),
'path': '/root/.cache/huggingface/datasets/downloads/extracted/05be0c29807a73c9b099873d2f5975dae6d05e9f7d577458a2466ecb9a2b0c6b/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_21921195.mp3',
'sampling_rate': 16000}
这似乎奏效了!让我们听几段音频文件,以便更好地理解数据集并验证音频是否正确加载。
import IPython.display as ipd
import numpy as np
import random
rand_int = random.randint(0, len(common_voice_train)-1)
print(common_voice_train[rand_int]["sentence"])
ipd.Audio(data=common_voice_train[rand_int]["audio"]["array"], autoplay=True, rate=16000)
打印输出
sunulan bütün teklifler i̇ngilizce idi
数据现在似乎已正确加载并重新采样。
可以听到,说话人的语速、口音和背景环境等都在变化。然而,总体而言,录音听起来清晰可接受,这对于众包朗读语音语料库来说是意料之中的。
让我们做最后一次检查,确认数据准备是否正确,通过打印语音输入的形状、其转写文本以及相应的采样率。
rand_int = random.randint(0, len(common_voice_train)-1)
print("Target text:", common_voice_train[rand_int]["sentence"])
print("Input array shape:", common_voice_train[rand_int]["audio"]["array"].shape)
print("Sampling rate:", common_voice_train[rand_int]["audio"]["sampling_rate"])
打印输出
Target text: makedonya bu yıl otuz adet tyetmiş iki tankı aldı
Input array shape: (71040,)
Sampling rate: 16000
好的!一切看起来都没问题——数据是一维数组,采样率总是 16kHz,目标文本也已规范化。
最后,我们可以利用 Wav2Vec2Processor
将数据处理成 Wav2Vec2ForCTC
训练所需的格式。为此,我们使用 Dataset 的 map(...)
函数。
首先,我们加载并重采样音频数据,只需调用 batch["audio"]
。其次,我们从加载的音频文件中提取 input_values
。在我们的例子中,Wav2Vec2Processor
只对数据进行归一化。然而,对于其他语音模型,此步骤可能包括更复杂的特征提取,例如 Log-Mel 特征提取。第三,我们将转录编码为标签 ID。
注意:此映射函数是 Wav2Vec2Processor
类应如何使用的良好示例。在“正常”情况下,调用 processor(...)
会被重定向到 Wav2Vec2FeatureExtractor
的调用方法。但是,当将处理器包装到 as_target_processor
上下文中时,相同的方法会被重定向到 Wav2Vec2CTCTokenizer
的调用方法。有关更多信息,请查阅文档。
def prepare_dataset(batch):
audio = batch["audio"]
# batched output is "un-batched"
batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
batch["input_length"] = len(batch["input_values"])
with processor.as_target_processor():
batch["labels"] = processor(batch["sentence"]).input_ids
return batch
让我们将数据准备函数应用到所有样本上。
common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names)
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names)
注意:目前 datasets
利用 torchaudio
和 librosa
进行音频加载和重采样。如果您希望实现自己的自定义数据加载/采样,请随意使用 "path"
列并忽略 "audio"
列。
长输入序列需要大量内存。XLS-R 基于 self-attention
。对于长输入序列,内存需求与输入长度呈二次方关系(参见这篇 Reddit 帖子)。如果此演示因“内存不足”错误而崩溃,您可能需要取消注释以下行,以过滤所有长度超过 5 秒的训练序列。
#max_input_length_in_sec = 5.0
#common_voice_train = common_voice_train.filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_length"])
太棒了,现在我们准备好开始训练了!
训练
数据已处理完毕,我们已准备好开始设置训练管道。我们将使用 🤗 的 Trainer,为此我们主要需要执行以下操作:
定义一个数据整理器。与大多数 NLP 模型不同,XLS-R 的输入长度远大于输出长度。例如,输入长度为 50000 的样本的输出长度不超过 100。鉴于输入大小较大,动态填充训练批次效率更高,这意味着所有训练样本应仅填充到其批次中最长的样本,而不是整体最长的样本。因此,微调 XLS-R 需要一个特殊的填充数据整理器,我们将在下面定义。
评估指标。在训练期间,模型应以词错误率进行评估。我们应该相应地定义一个
compute_metrics
函数。加载预训练检查点。我们需要加载预训练检查点并对其进行正确配置以进行训练。
定义训练配置。
在微调模型后,我们将在测试数据上对其进行正确评估,并验证它确实学会了正确转写语音。
设置训练器
我们从定义数据整理器开始。数据整理器的代码是从 这个示例 复制的。
不深入细节,与常见的数据整理器不同,此数据整理器对 input_values
和 labels
进行不同的处理,因此对它们应用单独的填充函数(再次利用 XLS-R 处理器上下文管理器)。这是必要的,因为在语音中,输入和输出属于不同的模态,这意味着它们不应由相同的填充函数处理。与常见的数据整理器类似,标签中的填充标记用 -100
填充,以便在计算损失时**不**考虑这些标记。
import torch
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
@dataclass
class DataCollatorCTCWithPadding:
"""
Data collator that will dynamically pad the inputs received.
Args:
processor (:class:`~transformers.Wav2Vec2Processor`)
The processor used for proccessing the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
"""
processor: Wav2Vec2Processor
padding: Union[bool, str] = True
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need
# different padding methods
input_features = [{"input_values": feature["input_values"]} for feature in features]
label_features = [{"input_ids": feature["labels"]} for feature in features]
batch = self.processor.pad(
input_features,
padding=self.padding,
return_tensors="pt",
)
with self.processor.as_target_processor():
labels_batch = self.processor.pad(
label_features,
padding=self.padding,
return_tensors="pt",
)
# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
batch["labels"] = labels
return batch
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
接下来,定义评估指标。如前所述,ASR 中最主要的指标是词错误率 (WER),因此我们在这个 notebook 中也使用它。
wer_metric = load_metric("wer")
模型将返回一系列 logit 向量:,其中 和 。
一个 logit 向量 包含我们之前定义的词汇表中每个词的对数几率,因此 config.vocab_size
。我们对模型最可能的预测感兴趣,因此取 logits 的 argmax(...)
。此外,我们将编码后的标签通过将 -100
替换为 pad_token_id
来转换回原始字符串,并在解码 ID 时确保连续的标记在 CTC 样式中**不**被分组为相同的标记 。
def compute_metrics(pred):
pred_logits = pred.predictions
pred_ids = np.argmax(pred_logits, axis=-1)
pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
pred_str = processor.batch_decode(pred_ids)
# we do not want to group tokens when computing the metrics
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
wer = wer_metric.compute(predictions=pred_str, references=label_str)
return {"wer": wer}
现在,我们可以加载 Wav2Vec2-XLS-R-300M 的预训练检查点。分词器的 pad_token_id
必须用于定义模型的 pad_token_id
,或者在 Wav2Vec2ForCTC
的情况下,也定义 CTC 的*空白标记* 。为了节省 GPU 内存,我们启用了 PyTorch 的 梯度检查点,并将损失减少设置为“mean”。
由于数据集相当小(约 6 小时训练数据),并且 Common Voice 相当嘈杂,微调 Facebook 的 wav2vec2-xls-r-300m 检查点似乎需要一些超参数调整。因此,我不得不尝试不同的 dropout 值、SpecAugment 的掩码 dropout 率、层 dropout 和学习率,直到训练看起来足够稳定。
注意:如果使用本笔记本在 Common Voice 的另一种语言上训练 XLS-R,这些超参数设置可能效果不佳。请根据您的用例随意调整这些参数。
from transformers import Wav2Vec2ForCTC
model = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-xls-r-300m",
attention_dropout=0.0,
hidden_dropout=0.0,
feat_proj_dropout=0.0,
mask_time_prob=0.05,
layerdrop=0.0,
ctc_loss_reduction="mean",
pad_token_id=processor.tokenizer.pad_token_id,
vocab_size=len(processor.tokenizer),
)
XLS-R 的第一个组件由一堆 CNN 层组成,用于从原始语音信号中提取具有声学意义但与上下文无关的特征。模型的这一部分在预训练期间已经充分训练,并且如论文所述,不再需要微调。因此,我们可以将*特征提取*部分的所有参数的 requires_grad
设置为 False
。
model.freeze_feature_extractor()
最后一步,我们定义所有与训练相关的参数。对其中一些参数进行更多解释:
group_by_length
通过将输入长度相似的训练样本分组到一个批次中,使训练更高效。这可以通过大大减少模型中无用填充标记的总数来显著加快训练时间。learning_rate
和weight_decay
经过启发式调整,直到微调变得稳定。请注意,这些参数强烈依赖于 Common Voice 数据集,并且对于其他语音数据集可能不是最优的。
关于其他参数的更多解释,可以查看文档。
在训练期间,每 400 个训练步骤将异步上传一个检查点到 Hub。这允许您在模型仍在训练时也可以使用演示小部件。
注意:如果不想将模型检查点上传到 Hub,只需设置 push_to_hub=False
。
from transformers import TrainingArguments
training_args = TrainingArguments(
output_dir=repo_name,
group_by_length=True,
per_device_train_batch_size=16,
gradient_accumulation_steps=2,
evaluation_strategy="steps",
num_train_epochs=30,
gradient_checkpointing=True,
fp16=True,
save_steps=400,
eval_steps=400,
logging_steps=400,
learning_rate=3e-4,
warmup_steps=500,
save_total_limit=2,
push_to_hub=True,
)
现在,所有实例都可以传递给 Trainer,我们准备开始训练了!
from transformers import Trainer
trainer = Trainer(
model=model,
data_collator=data_collator,
args=training_args,
compute_metrics=compute_metrics,
train_dataset=common_voice_train,
eval_dataset=common_voice_test,
tokenizer=processor.feature_extractor,
)
为了使模型独立于说话人语速,在 CTC 中,连续的相同标记简单地被分组为单个标记。然而,在解码时,编码的标签不应被分组,因为它们与模型的预测标记不对应,这就是为什么必须传递 group_tokens=False
参数。如果我们不传递此参数,像 "hello"
这样的词将被错误地编码并解码为 "helo"
。 空白标记允许模型通过强制在两个 l 之间插入空白标记来预测像 "hello"
这样的词。我们模型的 "hello"
的 CTC 符合预测将是 [PAD] [PAD] "h" "e" "e" "l" "l" [PAD] "l" "o" "o" [PAD]
。
训练
训练将花费数小时,具体取决于分配给此笔记本的 GPU。虽然训练好的模型在土耳其语 Common Voice 的测试数据上取得了令人满意的结果,但它绝不是一个最优微调的模型。本笔记本的目的是演示如何在 ASR 数据集上微调 XLS-R XLSR-Wav2Vec2。
根据分配给您的 Google Colab 的 GPU,您可能会在这里看到“内存不足”错误。在这种情况下,最好将 per_device_train_batch_size
减小到 8 甚至更少,并增加 gradient_accumulation
。
trainer.train()
打印输出
训练损失 | 轮次 | 步骤 | 验证损失 | 词错误率 (Wer) |
---|---|---|---|---|
3.8842 | 3.67 | 400 | 0.6794 | 0.7000 |
0.4115 | 7.34 | 800 | 0.4304 | 0.4548 |
0.1946 | 11.01 | 1200 | 0.4466 | 0.4216 |
0.1308 | 14.68 | 1600 | 0.4526 | 0.3961 |
0.0997 | 18.35 | 2000 | 0.4567 | 0.3696 |
0.0784 | 22.02 | 2400 | 0.4193 | 0.3442 |
0.0633 | 25.69 | 2800 | 0.4153 | 0.3347 |
0.0498 | 29.36 | 3200 | 0.4077 | 0.3195 |
训练损失和验证 WER 都在稳步下降。
您现在可以将训练结果上传到 Hub,只需执行此指令即可
trainer.push_to_hub()
你现在可以和所有的朋友、家人、心爱的宠物分享这个模型:他们都可以用“your-username/the-name-you-picked”这个标识符来加载它,例如:
from transformers import AutoModelForCTC, Wav2Vec2Processor
model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xls-r-300m-tr-colab")
processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wav2vec2-large-xls-r-300m-tr-colab")
有关如何微调 XLS-R 的更多示例,请参阅官方 🤗 Transformers 示例。
评估
最后,我们加载模型并验证它是否确实学会了转录土耳其语语音。
让我们首先加载预训练检查点。
model = Wav2Vec2ForCTC.from_pretrained(repo_name).to("cuda")
processor = Wav2Vec2Processor.from_pretrained(repo_name)
现在,我们只取测试集中的第一个示例,通过模型运行它,并取 logits 的 argmax(...)
来检索预测的标记 ID。
input_dict = processor(common_voice_test[0]["input_values"], return_tensors="pt", padding=True)
logits = model(input_dict.input_values.to("cuda")).logits
pred_ids = torch.argmax(logits, dim=-1)[0]
强烈建议将 sampling_rate
参数传递给此函数。否则可能会导致难以调试的静默错误。
我们对 common_voice_test
进行了相当大的修改,因此数据集实例不再包含原始句子标签。因此,我们重新使用原始数据集来获取第一个示例的标签。
common_voice_test_transcription = load_dataset("common_voice", "tr", data_dir="./cv-corpus-6.1-2020-12-11", split="test")
最后,我们可以对示例进行解码。
print("Prediction:")
print(processor.decode(pred_ids))
print("\nReference:")
print(common_voice_test_transcription[0]["sentence"].lower())
打印输出
预测字符串 | 目标文本 |
---|---|
hatta küçük şeyleri için bir büyt bir şeyleri kolluyor veyınıki çuk şeyler için bir bir mizi inciltiyoruz | hayatta küçük şeyleri kovalıyor ve yine küçük şeyler için birbirimizi incitiyoruz. |
好的!转录无疑可以从我们的预测中识别出来,但它还不够完美。模型训练时间再长一点,在数据预处理上投入更多时间,特别是使用语言模型进行解码,肯定会提高模型的整体性能。
然而,对于低资源语言的演示模型来说,结果还是相当不错的🤗。