如何在 HuggingFace 生态系统中使用 SSAST 模型权重?

自监督音频频谱图 Transformer (SSAST) 模型提供了最先进的音频分类能力 [1, 2]。自监督学习允许使用未标记数据,从而提高模型性能和特征学习。
与有监督 AST 模型不同,SSAST 训练只能通过研究存储库中的原始实现进行,因此使用该模型可能很麻烦。然而,在预训练之后,权重可以轻松加载到 HuggingFace Transformers AST 实现中,以便在下游任务上进行微调,同时利用 HuggingFace 生态系统。
本教程将指导您完成将 SSAST 模型权重集成到 HuggingFace 生态系统中的过程,从而简化微调和部署,并使其更容易为更广泛的受众所用。
为什么要使用 SSAST 权重?
通过将 SSAST 权重加载到 HuggingFace Transformers AST 实现中,您可以
- 受益于无标签数据的自监督学习:通过使用原始 SSAST 实现预训练模型并在 HuggingFace 生态系统中微调模型,从而提高下游任务的模型性能。
- 利用 HuggingFace 强大且用户友好的工具:利用 HuggingFace 全面的工具套件进行模型训练、评估和部署。
- 摆脱“脆弱”的研究存储库:通过将模型集成到健壮且受良好支持的 HuggingFace 平台,避免兼容性和依赖性问题。
让我们开始逐步加载权重。
加载 SSAST 权重的分步指南
使用 pip 安装所有必需的包
pip install 'transformers[torch]'
1. 配置架构
首先,使用 HuggingFace transformers 库中的 ASTConfig
类配置我们要加载权重的 SSAST 模型架构。
from transformers import ASTConfig, ASTModel
config = ASTConfig(
architectures=["ASTModel"],
frequency_stride=16,
time_stride=16,
hidden_size=768,
max_length=1024,
num_attention_heads=12,
num_hidden_layers=12,
num_mel_bins=128,
qkv_bias=True
)
在上面的代码片段中,我配置了 16-16 补丁基础模型。权重已在 SSAST 存储库或此处提供下载链接。如果您已使用自定义架构预训练了自己的模型,则需要相应地配置它。
请查看 SSAST 存储库中的其他预训练模型。
2. 实例化 AST 模型
接下来,使用指定的配置创建 ASTModel
实例。
model = ASTModel(config=config)
如果您尚未使用原始 SSAST 存储库训练模型,您可以简单地下载存储库中可用的任何预训练模型的权重。
要将权重加载到 transformers AST 实现中,您需要从 state_dict
加载权重。
import torch
model.load_state_dict(torch.load("./SSAST-Base-Frame-400.pth"))
加载时,您会看到指示某些权重未使用的消息。当从在不同任务或使用不同架构上训练的检查点初始化 ASTModel
时,这是预期行为。
Some weights of the model checkpoint at ./SSAST-Base-Patch-400.pth were not used when initializing ASTModel: ['module.v.blocks.3.mlp.fc2.bias', 'module.gpredlayer.2.bias', 'module.v.blocks.10.attn.qkv.bias', 'module.v.blocks.11.mlp.fc1.bias', 'module.v.blocks.1.mlp.fc1.weight', ...
- This IS expected if you are initializing ASTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ASTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ASTModel were not initialized from the model checkpoint at ./SSAST-Base-Patch-400.pth and are newly initialized: ['encoder.layer.4.layernorm_before.weight', 'encoder.layer.1.attention.attention.value.weight', 'encoder.layer.10.attention.attention.query.weight', ...
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
主要区别在于层的命名约定。HuggingFace 的 ASTModel
与原始 SSAST 模型使用不同的命名方案。在 HuggingFace 实现中,“encoder.layer.[0–12]” 对应于 “module.v.blocks.[0–12]”。
在下一步中,我们将解决此问题。
3. 将 SSAST 状态字典转换为 HuggingFace 格式
要解决此问题,您可以将 SSAST 层名称映射到相应的 HuggingFace 层名称。下面是一个执行此转换的函数。
import torch
def convert_ssast_state_dict_to_astmodel(pretrained_dict, layers: int = 12):
conversion_dict = {
'module.v.cls_token': 'embeddings.cls_token',
'module.v.dist_token': 'embeddings.distillation_token',
'module.v.pos_embed': 'embeddings.position_embeddings',
'module.v.patch_embed.proj.weight': 'embeddings.patch_embeddings.projection.weight',
'module.v.patch_embed.proj.bias': 'embeddings.patch_embeddings.projection.bias',
'module.v.norm.weight': 'layernorm.weight',
'module.v.norm.bias': 'layernorm.bias',
}
for i in range(layers):
conversion_dict[
f'module.v.blocks.{i}.norm1.weight'] = f'encoder.layer.{i}.layernorm_before.weight'
conversion_dict[
f'module.v.blocks.{i}.norm1.bias'] = f'encoder.layer.{i}.layernorm_before.bias'
conversion_dict[f'module.v.blocks.{i}.attn.qkv.weight'] = [
f'encoder.layer.{i}.attention.attention.query.weight',
f'encoder.layer.{i}.attention.attention.key.weight',
f'encoder.layer.{i}.attention.attention.value.weight'
]
conversion_dict[f'module.v.blocks.{i}.attn.qkv.bias'] = [
f'encoder.layer.{i}.attention.attention.query.bias',
f'encoder.layer.{i}.attention.attention.key.bias',
f'encoder.layer.{i}.attention.attention.value.bias'
]
conversion_dict[
f'module.v.blocks.{i}.attn.proj.weight'] = f'encoder.layer.{i}.attention.output.dense.weight'
conversion_dict[
f'module.v.blocks.{i}.attn.proj.bias'] = f'encoder.layer.{i}.attention.output.dense.bias'
conversion_dict[
f'module.v.blocks.{i}.norm2.weight'] = f'encoder.layer.{i}.layernorm_after.weight'
conversion_dict[
f'module.v.blocks.{i}.norm2.bias'] = f'encoder.layer.{i}.layernorm_after.bias'
conversion_dict[
f'module.v.blocks.{i}.mlp.fc1.weight'] = f'encoder.layer.{i}.intermediate.dense.weight'
conversion_dict[
f'module.v.blocks.{i}.mlp.fc1.bias'] = f'encoder.layer.{i}.intermediate.dense.bias'
conversion_dict[
f'module.v.blocks.{i}.mlp.fc2.weight'] = f'encoder.layer.{i}.output.dense.weight'
conversion_dict[
f'module.v.blocks.{i}.mlp.fc2.bias'] = f'encoder.layer.{i}.output.dense.bias'
}
converted_dict = {}
for key, value in pretrained_dict.items():
if key in conversion_dict:
mapped_key = conversion_dict[key]
if isinstance(mapped_key, list):
# Assuming value is split equally among q, k, v if it's a concatenated tensor
split_size = value.shape[0] // 3
converted_dict[mapped_key[0]] = value[:split_size]
converted_dict[mapped_key[1]] = value[split_size:2 * split_size]
converted_dict[mapped_key[2]] = value[2 * split_size:]
else:
converted_dict[mapped_key] = value
return converted_dict
4. 加载转换后的状态字典
加载 SSAST 检查点,转换 state_dict
,并初始化 ASTModel
。
ssast_state_dict = torch.load("./SSAST-Base-Patch-400.pth")
converted = convert_ssast_state_dict_to_astmodel(ssast_state_dict)
model.load_state_dict(converted)
如果转换成功,您应该看到
“输出[1]:<所有键匹配成功>”
您现在可以使用带有 SSAST 预训练权重的 ASTModel
来执行任何任务,例如创建嵌入或将其集成到您的自定义训练管道中。
使用 SSAST 模型权重进行音频分类
如果要使用权重初始化音频分类器,则必须进行一些微小调整。
要使用 SSAST 权重实例化 ASTForAudioClassification
模型,请在编码器和嵌入层名称中添加 “audio_spectrogram_transformer.” 以使其正确匹配。例如
'module.v.blocks.0.norm1.weight' --> 'audio_spectrogram_transformer.encoder.layer.0.layernorm_before.weight'
'module.v.cls_token' --> 'audio_spectrogram_transformer.embeddings.cls_token'
'module.v.norm.weight'--> 'audio_spectrogram_transformer.layernorm.weight'
由于分类头将用零初始化,因此请务必在此之后调用 model.initialize()
。
from transformers import ASTForAudioClassification
model = ASTForAudioClassification(config=config)
model.load_state_dict(converted, strict=False)
model.initialize()
现在,您的 ASTForAudioClassification
模型已准备好在音频分类任务上进行微调。
了解如何在本文中微调 AST。
结论
本指南演示了将使用原始实现预训练的 SSAST 模型权重加载到 HuggingFace ASTModel
类中是多么容易。将 SSAST 模型权重集成到 HuggingFace 生态系统可以为 HuggingFace 生态系统中的 AST 训练或微调管道释放强大的自监督学习能力。
展望
在过去的 1.5 年里,我一直在研究 AST 模型,并开始撰写一系列关于如何训练模型以及如何针对音频领域中的特定问题进行调整的文章。这只是该系列的第一部分。如果您有兴趣扩展您在应用于音频的机器学习方面的知识,请务必查看我在 medium 上的音频文章列表。
第二篇文章是关于如何使用 Hugging Face Transformers 微调音频频谱图 Transformer (AST),并已由 Towards Data Science 发表。
建模愉快!
感谢阅读!我叫Marius Steger,是 @Renumics 的机器学习工程师——我们开发了 Spotlight,一个用于交互式数据探索和可视化的开源工具,它与 Hugging Face 数据集集成。如果您想了解更多关于该工具的信息,请查看我同事 Markus 的这篇社区文章。
参考文献
[1] Papers With Code 排行榜:AudioSet 上的音频分类
[2] Yuan Gong, Cheng-I Jeff Lai, Yu-An Chung, James Glass: SSAST: Self-SSAST: Self-Supervised Audio Spectrogram Transformer. (2021), arxiv