Transformers 文档
特征提取器
并获得增强的文档体验
开始使用
特征提取器
特征提取器将音频数据预处理成给定模型所需的正确格式。它接收原始音频信号并将其转换为可馈送到模型的张量。张量形状取决于模型,但特征提取器会根据你使用的模型为你正确预处理音频数据。特征提取器还包括填充、截断和重采样方法。
调用 from_pretrained() 从 Hugging Face Hub 或本地目录加载特征提取器及其预处理器配置。特征提取器和预处理器配置保存在 preprocessor_config.json 文件中。
将音频信号(通常存储在 `array` 中)传递给特征提取器,并将 `sampling_rate` 参数设置为预训练音频模型的采样率。重要的是,音频数据的采样率必须与预训练音频模型训练所用数据的采样率相匹配。
from transformers import AutoFeatureExtractor
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
processed_sample = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=16000)
processed_sample
{'input_values': [array([ 9.4472744e-05, 3.0777880e-03, -2.8888427e-03, ...,
-2.8888427e-03, 9.4472744e-05, 9.4472744e-05], dtype=float32)]}
特征提取器返回一个输入,`input_values`,该输入已准备好供模型使用。
本指南将引导你了解特征提取器类以及如何预处理音频数据。
特征提取器类
Transformers 特征提取器继承自基类 SequenceFeatureExtractor,该类是 FeatureExtractionMixin 的子类。
- SequenceFeatureExtractor 提供了一个方法 pad(),用于将序列填充到特定长度,以避免序列长度不一致。
- FeatureExtractionMixin 提供了 from_pretrained() 和 save_pretrained() 来加载和保存特征提取器。
有两种方法可以加载特征提取器:AutoFeatureExtractor 和模型特定的特征提取器类。
AutoClass API 会自动为给定模型加载正确的特征提取器。
使用 from_pretrained() 加载特征提取器。
from transformers import AutoFeatureExtractor
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
预处理
特征提取器期望输入为特定形状的 PyTorch 张量。确切的输入形状可能因你使用的特定音频模型而异。
例如,Whisper 期望 `input_features` 是形状为 `(batch_size, feature_size, sequence_length)` 的张量,而 Wav2Vec2 期望 `input_values` 是形状为 `(batch_size, sequence_length)` 的张量。
特征提取器会为所使用的任何音频模型生成正确的输入形状。
特征提取器还设置音频文件的采样率(每秒采样的音频信号值数量)。你的音频数据的采样率必须与预训练模型训练所用数据集的采样率匹配。该值通常在模型卡中给出。
使用 from_pretrained() 加载数据集和特征提取器。
from datasets import load_dataset, Audio
from transformers import AutoFeatureExtractor
dataset = load_dataset("PolyAI/minds14", name="en-US", split="train")
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
查看数据集中的第一个示例,并访问包含原始音频信号 `array` 的 `audio` 列。
dataset[0]["audio"]["array"]
array([ 0. , 0.00024414, -0.00024414, ..., -0.00024414,
0. , 0. ])
特征提取器将 `array` 预处理成给定音频模型的预期输入格式。使用 `sampling_rate` 参数设置合适的采样率。
processed_dataset = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=16000)
processed_dataset
{'input_values': [array([ 9.4472744e-05, 3.0777880e-03, -2.8888427e-03, ...,
-2.8888427e-03, 9.4472744e-05, 9.4472744e-05], dtype=float32)]}
填充
音频序列长度不同是一个问题,因为 Transformers 期望所有序列具有相同的长度,以便进行批处理。长度不等的序列无法进行批处理。
dataset[0]["audio"]["array"].shape
(86699,)
dataset[1]["audio"]["array"].shape
(53248,)
填充会添加一个特殊的*填充标记*,以确保所有序列都具有相同的长度。特征提取器会将一个 `0`(解释为静音)添加到 `array` 中进行填充。设置 `padding=True` 可以将序列填充到批次中最长序列的长度。
def preprocess_function(examples):
audio_arrays = [x["array"] for x in examples["audio"]]
inputs = feature_extractor(
audio_arrays,
sampling_rate=16000,
padding=True,
)
return inputs
processed_dataset = preprocess_function(dataset[:5])
processed_dataset["input_values"][0].shape
(86699,)
processed_dataset["input_values"][1].shape
(86699,)
截断
模型只能处理到一定长度的序列,否则会崩溃。
截断是一种从序列中移除多余标记以确保其不超过最大长度的策略。将 `truncation=True` 设置为截断序列到 `max_length` 参数指定的长度。
def preprocess_function(examples):
audio_arrays = [x["array"] for x in examples["audio"]]
inputs = feature_extractor(
audio_arrays,
sampling_rate=16000,
max_length=50000,
truncation=True,
)
return inputs
processed_dataset = preprocess_function(dataset[:5])
processed_dataset["input_values"][0].shape
(50000,)
processed_dataset["input_values"][1].shape
(50000,)
重采样
Datasets 库也可以重采样音频数据,使其与音频模型预期的采样率匹配。这种方法在加载音频数据时即时进行重采样,这可能比就地重采样整个数据集更快。
你正在处理的音频数据集的采样率为 8kHz,而预训练模型期望的采样率为 16kHz。
dataset[0]["audio"]
{'path': '/root/.cache/huggingface/datasets/downloads/extracted/f507fdca7f475d961f5bb7093bcc9d544f16f8cab8608e772a2ed4fbeb4d6f50/en-US~JOINT_ACCOUNT/602ba55abb1e6d0fbce92065.wav',
'array': array([ 0. , 0.00024414, -0.00024414, ..., -0.00024414,
0. , 0. ]),
'sampling_rate': 8000}
对 `audio` 列调用 cast_column,将采样率上采样到 16kHz。
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
当你加载数据集样本时,它现在被重采样到 16kHz。
dataset[0]["audio"]
{'path': '/root/.cache/huggingface/datasets/downloads/extracted/f507fdca7f475d961f5bb7093bcc9d544f16f8cab8608e772a2ed4fbeb4d6f50/en-US~JOINT_ACCOUNT/602ba55abb1e6d0fbce92065.wav',
'array': array([ 1.70562416e-05, 2.18727451e-04, 2.28099874e-04, ...,
3.43842403e-05, -5.96364771e-06, -1.76846661e-05]),
'sampling_rate': 16000}