Transformers 文档

骨干网络

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

骨干网络

更高级别的计算机视觉任务,例如目标检测或图像分割,会使用多个模型协同生成预测。其中一个单独的模型用于**骨干网络**、颈部和头部。骨干网络从输入图像中提取有用特征到特征图,颈部组合并处理特征图,头部则利用这些特征图进行预测。

使用 from_pretrained() 加载骨干网络,并使用 out_indices 参数确定从哪个索引层提取特征图。

from transformers import AutoBackbone

model = AutoBackbone.from_pretrained("microsoft/swin-tiny-patch4-window7-224", out_indices=(1,))

本指南介绍了骨干网络类、来自 timm 库的骨干网络以及如何使用它们提取特征。

骨干网络类

骨干网络类有两个。

  • BackboneMixin 允许您加载骨干网络,并包含提取特征图和索引的函数。
  • BackboneConfigMixin 允许您设置骨干网络配置的特征图和索引。

请参阅骨干网络 API 文档以查看哪些模型支持骨干网络。

有两种加载 Transformers 骨干网络的方法:AutoBackbone 和模型特定的骨干网络类。

AutoBackbone
模型特定的骨干网络

AutoClass API 在支持的情况下,会自动使用 from_pretrained() 加载预训练的视觉模型作为骨干网络。

out_indices 参数设置为您希望从中获取特征图的层。如果您知道层的名称,也可以使用 out_features。这些参数可以互换使用,但如果两者都使用,请确保它们指向同一层。

当未使用 out_indicesout_features 时,骨干网络会返回最后一层的特征图。以下示例代码使用 out_indices=(1,) 从第一层获取特征图。

from transformers import AutoImageProcessor, AutoBackbone

model = AutoBackbone.from_pretrained("microsoft/swin-tiny-patch4-window7-224", out_indices=(1,))

timm 骨干网络

timm 是用于训练和推理的视觉模型集合。Transformers 支持 timm 模型作为骨干网络,通过 TimmBackboneTimmBackboneConfig 类实现。

设置 use_timm_backbone=True 以加载预训练的 timm 权重,并设置 use_pretrained_backbone 以使用预训练或随机初始化的权重。

from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation

config = MaskFormerConfig(backbone="resnet50", use_timm_backbone=True, use_pretrained_backbone=True)
model = MaskFormerForInstanceSegmentation(config)

您也可以显式调用 TimmBackboneConfig 类来加载和创建预训练的 timm 骨干网络。

from transformers import TimmBackboneConfig

backbone_config = TimmBackboneConfig("resnet50", use_pretrained_backbone=True)

将骨干网络配置传递给模型配置,并使用骨干网络实例化模型头部,即 MaskFormerForInstanceSegmentation

from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation

config = MaskFormerConfig(backbone_config=backbone_config)
model = MaskFormerForInstanceSegmentation(config)

特征提取

骨干网络用于提取图像特征。通过骨干网络传递图像以获取特征图。

加载并预处理图像,然后将其传递给骨干网络。以下示例从第一层提取特征图。

from transformers import AutoImageProcessor, AutoBackbone
import torch
from PIL import Image
import requests

model = AutoBackbone.from_pretrained("microsoft/swin-tiny-patch4-window7-224", out_indices=(1,))
processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(image, return_tensors="pt")
outputs = model(**inputs)

特征存储在输出的 feature_maps 属性中并从中访问。

feature_maps = outputs.feature_maps
list(feature_maps[0].shape)
[1, 96, 56, 56]
< > 在 GitHub 上更新