Transformers 文档
特征提取器
并获得增强的文档体验
开始使用
特征提取器
特征提取器将音频数据预处理为给定模型的正确格式。它接收原始音频信号,并将其转换为可以馈送到模型的张量。张量的形状取决于模型,但特征提取器会为您使用的模型正确地预处理音频数据。特征提取器还包括用于填充、截断和重采样的方法。
调用 from_pretrained() 从 Hugging Face Hub 或本地目录加载特征提取器及其预处理器配置。特征提取器和预处理器配置保存在 preprocessor_config.json 文件中。
将音频信号(通常存储在 array
中)传递给特征提取器,并将 sampling_rate
参数设置为预训练音频模型的采样率。重要的是音频数据的采样率与预训练音频模型训练数据的采样率相匹配。
from transformers import AutoFeatureExtractor
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
processed_sample = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=16000)
processed_sample
{'input_values': [array([ 9.4472744e-05, 3.0777880e-03, -2.8888427e-03, ...,
-2.8888427e-03, 9.4472744e-05, 9.4472744e-05], dtype=float32)]}
特征提取器返回一个输入 input_values
,该输入已准备好供模型使用。
本指南将引导您了解特征提取器类以及如何预处理音频数据。
特征提取器类
Transformers 特征提取器继承自 SequenceFeatureExtractor 基类,该基类继承了 FeatureExtractionMixin。
- SequenceFeatureExtractor 提供了一个将序列 pad() 填充到特定长度的方法,以避免序列长度不均匀。
- FeatureExtractionMixin 提供了 from_pretrained() 和 save_pretrained() 来加载和保存特征提取器。
您可以通过两种方式加载特征提取器:AutoFeatureExtractor 和模型特定的特征提取器类。
AutoClass API 自动为给定模型加载正确的特征提取器。
使用 from_pretrained() 加载特征提取器。
from transformers import AutoFeatureExtractor
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
预处理
特征提取器期望输入为特定形状的 PyTorch 张量。确切的输入形状可能因您使用的特定音频模型而异。
例如,Whisper 期望 input_features
为形状为 (batch_size, feature_size, sequence_length)
的张量,但 Wav2Vec2 期望 input_values
为形状为 (batch_size, sequence_length)
的张量。
特征提取器为您正在使用的任何音频模型生成正确的输入形状。
特征提取器还设置音频文件的采样率(每秒采集的音频信号值数量)。您的音频数据的采样率必须与预训练模型训练所用数据集的采样率相匹配。此值通常在模型卡中给出。
使用 from_pretrained() 加载数据集和特征提取器。
from datasets import load_dataset, Audio
from transformers import AutoFeatureExtractor
dataset = load_dataset("PolyAI/minds14", name="en-US", split="train")
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
查看数据集中的第一个示例,并访问包含原始音频信号 array
的 audio
列。
dataset[0]["audio"]["array"]
array([ 0. , 0.00024414, -0.00024414, ..., -0.00024414,
0. , 0. ])
特征提取器将 array
预处理为给定音频模型期望的输入格式。使用 sampling_rate
参数设置合适的采样率。
processed_dataset = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=16000)
processed_dataset
{'input_values': [array([ 9.4472744e-05, 3.0777880e-03, -2.8888427e-03, ...,
-2.8888427e-03, 9.4472744e-05, 9.4472744e-05], dtype=float32)]}
填充
音频序列长度不同是一个问题,因为 Transformers 期望所有序列具有相同的长度,以便可以进行批处理。不均匀的序列长度无法进行批处理。
dataset[0]["audio"]["array"].shape
(86699,)
dataset[1]["audio"]["array"].shape
(53248,)
填充添加一个特殊的填充标记,以确保所有序列都具有相同的长度。特征提取器向 array
添加一个 0
(解释为静音)以进行填充。设置 padding=True
可将序列填充到批处理中最长序列的长度。
def preprocess_function(examples):
audio_arrays = [x["array"] for x in examples["audio"]]
inputs = feature_extractor(
audio_arrays,
sampling_rate=16000,
padding=True,
)
return inputs
processed_dataset = preprocess_function(dataset[:5])
processed_dataset["input_values"][0].shape
(86699,)
processed_dataset["input_values"][1].shape
(86699,)
截断
模型只能处理一定长度的序列,超过长度会崩溃。
截断是一种从序列中删除多余 token 的策略,以确保它不超过最大长度。设置 truncation=True
可将序列截断为 max_length
参数中指定的长度。
def preprocess_function(examples):
audio_arrays = [x["array"] for x in examples["audio"]]
inputs = feature_extractor(
audio_arrays,
sampling_rate=16000,
max_length=50000,
truncation=True,
)
return inputs
processed_dataset = preprocess_function(dataset[:5])
processed_dataset["input_values"][0].shape
(50000,)
processed_dataset["input_values"][1].shape
(50000,)
重采样
Datasets 库也可以重采样音频数据,以匹配音频模型期望的采样率。此方法在加载音频数据时动态重采样,这可能比就地重采样整个数据集更快。
您一直在处理的音频数据集的采样率为 8kHz,而预训练模型期望的采样率为 16kHz。
dataset[0]["audio"]
{'path': '/root/.cache/huggingface/datasets/downloads/extracted/f507fdca7f475d961f5bb7093bcc9d544f16f8cab8608e772a2ed4fbeb4d6f50/en-US~JOINT_ACCOUNT/602ba55abb1e6d0fbce92065.wav',
'array': array([ 0. , 0.00024414, -0.00024414, ..., -0.00024414,
0. , 0. ]),
'sampling_rate': 8000}
在 audio
列上调用 cast_column,将采样率上采样到 16kHz。
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
当您加载数据集样本时,它现在被重采样为 16kHz。
dataset[0]["audio"]
{'path': '/root/.cache/huggingface/datasets/downloads/extracted/f507fdca7f475d961f5bb7093bcc9d544f16f8cab8608e772a2ed4fbeb4d6f50/en-US~JOINT_ACCOUNT/602ba55abb1e6d0fbce92065.wav',
'array': array([ 1.70562416e-05, 2.18727451e-04, 2.28099874e-04, ...,
3.43842403e-05, -5.96364771e-06, -1.76846661e-05]),
'sampling_rate': 16000}