Transformers 文档

骨干网络

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

骨干网络

更高级别的计算机视觉任务,例如对象检测或图像分割,会一起使用多个模型来生成预测。一个单独的模型用于骨干网络、颈部和头部。骨干网络从输入图像中提取有用的特征到特征图,颈部组合并处理特征图,头部使用它们进行预测。

使用 from_pretrained() 加载骨干网络,并使用 out_indices 参数来确定要从哪个索引给定的层提取特征图。

from transformers import AutoBackbone

model = AutoBackbone.from_pretrained("microsoft/swin-tiny-patch4-window7-224", out_indices=(1,))

本指南介绍了骨干网络类、来自 timm 库的骨干网络,以及如何使用它们提取特征。

骨干网络类

有两个骨干网络类。

  • BackboneMixin 允许您加载骨干网络,并包含用于提取特征图和索引的函数。
  • BackboneConfigMixin 允许您设置骨干网络配置的特征图和索引。

请参阅 Backbone API 文档,以查看哪些模型支持骨干网络。

加载 Transformers 骨干网络有两种方法:AutoBackbone 和模型特定的骨干网络类。

AutoBackbone
模型特定的骨干网络

AutoClass API 会自动加载预训练的视觉模型,并将其作为骨干网络,如果它被支持,通过 from_pretrained() 加载。

out_indices 参数设置为您想要从中获取特征图的层。如果您知道层的名称,您也可以使用 out_features。这些参数可以互换使用,但如果您同时使用两者,请确保它们指向同一层。

当不使用 out_indicesout_features 时,骨干网络会返回最后一层的特征图。下面的示例代码使用 out_indices=(1,) 来获取第一层的特征图。

from transformers import AutoImageProcessor, AutoBackbone

model = AutoBackbone.from_pretrained("microsoft/swin-tiny-patch4-window7-224", out_indices=(1,))

timm 骨干网络

timm 是用于训练和推理的视觉模型集合。Transformers 支持 timm 模型作为骨干网络,通过 TimmBackboneTimmBackboneConfig 类。

设置 use_timm_backbone=True 以加载预训练的 timm 权重,以及 use_pretrained_backbone 以使用预训练的或随机初始化的权重。

from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation

config = MaskFormerConfig(backbone="resnet50", use_timm_backbone=True, use_pretrained_backbone=True)
model = MaskFormerForInstanceSegmentation(config)

您也可以显式调用 TimmBackboneConfig 类来加载和创建预训练的 timm 骨干网络。

from transformers import TimmBackboneConfig

backbone_config = TimmBackboneConfig("resnet50", use_pretrained_backbone=True)

将骨干网络配置传递给模型配置,并使用骨干网络实例化模型头,MaskFormerForInstanceSegmentation

from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation

config = MaskFormerConfig(backbone_config=backbone_config)
model = MaskFormerForInstanceSegmentation(config)

特征提取

骨干网络用于提取图像特征。将图像传递到骨干网络以获取特征图。

加载和预处理图像,并将其传递给骨干网络。下面的示例从第一层提取特征图。

from transformers import AutoImageProcessor, AutoBackbone
import torch
from PIL import Image
import requests

model = AutoBackbone.from_pretrained("microsoft/swin-tiny-patch4-window7-224", out_indices=(1,))
processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(image, return_tensors="pt")
outputs = model(**inputs)

特征存储在输出的 feature_maps 属性中并从中访问。

feature_maps = outputs.feature_maps
list(feature_maps[0].shape)
[1, 96, 56, 56]
< > 在 GitHub 上更新