timm 文档

特征提取

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

特征提取

timm 中的所有模型都具有一致的机制,用于从模型中获取除分类任务之外的各种类型的特征。

倒数第二层特征 (预分类器特征)

可以通过几种方式获得倒数第二层模型层的特征,而无需进行模型手术(尽管可以随意进行手术)。首先必须决定是否需要池化或非池化特征。

非池化

有三种方法可以获得非池化特征。最终的非池化特征有时被称为最后的隐藏状态。在 timm 中,这包括最终的归一化层(例如 ViT 风格的模型),但不包括池化/类别令牌选择和最终的后池化层。

无需修改网络,可以在任何模型上调用 model.forward_features(input) 而不是通常的 model(input)。这将绕过网络的头部分类器和全局池化。

如果想要显式修改网络以返回非池化特征,可以创建没有分类器和池化的模型,或者稍后移除它。这两种方法都会从网络中移除与分类器相关的参数。

forward_features()

>>> import torch
>>> import timm
>>> m = timm.create_model('xception41', pretrained=True)
>>> o = m(torch.randn(2, 3, 299, 299))
>>> print(f'Original shape: {o.shape}')
>>> o = m.forward_features(torch.randn(2, 3, 299, 299))
>>> print(f'Unpooled shape: {o.shape}')

输出

Original shape: torch.Size([2, 1000])
Unpooled shape: torch.Size([2, 2048, 10, 10])

创建时不使用分类器和池化

>>> import torch
>>> import timm
>>> m = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool='')
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Unpooled shape: {o.shape}')

输出

Unpooled shape: torch.Size([2, 2048, 7, 7])

稍后移除它

>>> import torch
>>> import timm
>>> m = timm.create_model('densenet121', pretrained=True)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Original shape: {o.shape}')
>>> m.reset_classifier(0, '')
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Unpooled shape: {o.shape}')

输出

Original shape: torch.Size([2, 1000])
Unpooled shape: torch.Size([2, 1024, 7, 7])

将非池化输出链接到分类器

最后的隐藏状态可以使用 forward_head() 函数反馈到模型的头部。

>>> model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
>>> output = model.forward_features(torch.randn(2,3,256,256))
>>> print('Unpooled output shape:', output.shape)
>>> classified = model.forward_head(output)
>>> print('Classification output shape:', classified.shape)

输出

Unpooled output shape: torch.Size([2, 257, 512])
Classification output shape: torch.Size([2, 1000])

池化

要修改网络以返回池化特征,可以使用 forward_features() 并自行池化/展平结果,或者像上面一样修改网络,但保持池化不变。

创建时不使用分类器

>>> import torch
>>> import timm
>>> m = timm.create_model('resnet50', pretrained=True, num_classes=0)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Pooled shape: {o.shape}')

输出

Pooled shape: torch.Size([2, 2048])

稍后移除它

>>> import torch
>>> import timm
>>> m = timm.create_model('ese_vovnet19b_dw', pretrained=True)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Original shape: {o.shape}')
>>> m.reset_classifier(0)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Pooled shape: {o.shape}')

输出

Original shape: torch.Size([2, 1000])
Pooled shape: torch.Size([2, 1024])

多尺度特征图 (特征金字塔)

对象检测、分割、关键点和各种密集像素任务需要访问来自骨干网络在多个尺度上的特征图。这通常通过修改原始分类网络来完成。由于每个网络在结构上差异很大,因此在任何给定的对象检测或分割库中,只看到少数几个骨干网络受到支持是很常见的。

timm 允许一个一致的接口,用于创建任何包含的模型作为特征骨干网络,这些骨干网络输出选定级别的特征图。

可以通过向任何 create_model 调用添加参数 features_only=True 来创建特征骨干网络。默认情况下,大多数具有特征层次结构的模型将输出最多 5 个特征,最多可减少 32 倍。但这因模型而异,有些模型具有较少的层次结构级别,而有些模型(如 ViT)具有大量非层次结构的特征图,它们默认输出最后 3 个。可以将 out_indices 参数传递给 create_model 以指定您想要的特征。

创建特征图提取模型

>>> import torch
>>> import timm
>>> m = timm.create_model('resnest26d', features_only=True, pretrained=True)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> for x in o:
...     print(x.shape)

输出

torch.Size([2, 64, 112, 112])
torch.Size([2, 256, 56, 56])
torch.Size([2, 512, 28, 28])
torch.Size([2, 1024, 14, 14])
torch.Size([2, 2048, 7, 7])

查询特征信息

创建特征骨干网络后,可以对其进行查询,以向下游头部提供通道或分辨率缩减信息,而无需静态配置或硬编码的常量。`.feature_info` 属性是一个类,封装了有关特征提取点的信息。

>>> import torch
>>> import timm
>>> m = timm.create_model('regnety_032', features_only=True, pretrained=True)
>>> print(f'Feature channels: {m.feature_info.channels()}')
>>> o = m(torch.randn(2, 3, 224, 224))
>>> for x in o:
...     print(x.shape)

输出

Feature channels: [32, 72, 216, 576, 1512]
torch.Size([2, 32, 112, 112])
torch.Size([2, 72, 56, 56])
torch.Size([2, 216, 28, 28])
torch.Size([2, 576, 14, 14])
torch.Size([2, 1512, 7, 7])

选择特定特征级别或限制步幅

还有两个额外的创建参数会影响输出特征。

  • out_indices 选择要输出的索引
  • output_stride 限制网络的特征输出步幅(顺便说一下,在分类模式下也有效)

输出索引选择

所有模型都支持 out_indices 参数,但并非所有模型都具有相同的索引到特征步幅映射。查看代码或检查 feature_info 以进行比较。输出索引通常对应于 C(i+1) 特征级别(2^(i+1) 倍的缩减)。对于大多数 convnet 模型,索引 0 是步幅为 2 的特征,索引 4 是步幅为 32 的特征。对于许多 ViT 或 ViT-Conv 混合模型,可能存在许多或全部形状相同的特征图,或者层次结构和非层次结构特征图的组合。最好查看 feature_info 属性以查看特征数量、其相应的通道计数和缩减级别。

out_indices 支持负索引,这使得获取最后一个、倒数第二个等特征图变得容易。out_indices=(-2,) 将为任何模型返回倒数第二个特征图。

输出步幅(特征图扩张)

output_stride 是通过转换层以使用空洞卷积来实现的。这样做并不总是直接的,有些网络仅支持 output_stride=32

>>> import torch
>>> import timm
>>> m = timm.create_model('ecaresnet101d', features_only=True, output_stride=8, out_indices=(2, 4), pretrained=True)
>>> print(f'Feature channels: {m.feature_info.channels()}')
>>> print(f'Feature reduction: {m.feature_info.reduction()}')
>>> o = m(torch.randn(2, 3, 320, 320))
>>> for x in o:
...     print(x.shape)

输出

Feature channels: [512, 2048]
Feature reduction: [8, 8]
torch.Size([2, 512, 40, 40])
torch.Size([2, 2048, 40, 40])

灵活的中间特征图提取

除了在模型工厂中使用 features_only 之外,许多模型还支持 forward_intermediates() 方法,该方法提供了一种灵活的机制,用于提取中间特征图和最后的隐藏状态(可以链接到头部)。此外,此方法还支持一些特定于模型的功能,例如为某些模型返回类或 distill 前缀令牌。

forward_intermediates 函数一起的是 prune_intermediate_layers 函数,它允许从模型中修剪层,包括头部、最终规范化层和/或不需要的尾随块/阶段。

indices 参数用于 forward_intermediates()prune_intermediate_layers(),以选择要返回的特征或要移除的层。与 features_only API 的 out_indices 一样,indices 是特定于模型的,并选择返回哪些中间值。

在非层次结构的基于块的模型(例如 ViT)中,索引对应于块,在具有层次结构的阶段的模型中,它们通常对应于 stem + 每个层次结构阶段的输出。正索引(从头开始)和负索引(相对于结尾)都有效,并且 None 用于返回所有中间值。

prune_intermediate_layers() 调用返回一个索引变量,因为当模型被修剪时,负索引必须转换为绝对(正)索引。

model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
output, intermediates = model.forward_intermediates(torch.randn(2,3,256,256))
for i, o in enumerate(intermediates):
    print(f'Feat index: {i}, shape: {o.shape}')
Feat index: 0, shape: torch.Size([2, 512, 16, 16])
Feat index: 1, shape: torch.Size([2, 512, 16, 16])
Feat index: 2, shape: torch.Size([2, 512, 16, 16])
Feat index: 3, shape: torch.Size([2, 512, 16, 16])
Feat index: 4, shape: torch.Size([2, 512, 16, 16])
Feat index: 5, shape: torch.Size([2, 512, 16, 16])
Feat index: 6, shape: torch.Size([2, 512, 16, 16])
Feat index: 7, shape: torch.Size([2, 512, 16, 16])
Feat index: 8, shape: torch.Size([2, 512, 16, 16])
Feat index: 9, shape: torch.Size([2, 512, 16, 16])
Feat index: 10, shape: torch.Size([2, 512, 16, 16])
Feat index: 11, shape: torch.Size([2, 512, 16, 16])
model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
print('Original params:', sum([p.numel() for p in model.parameters()]))

indices = model.prune_intermediate_layers(indices=(-2,), prune_head=True, prune_norm=True)  # prune head, norm, last block
print('Pruned params:', sum([p.numel() for p in model.parameters()]))

intermediates = model.forward_intermediates(torch.randn(2,3,256,256), indices=indices, intermediates_only=True)  # return penultimate intermediate
for o in intermediates:    
    print(f'Feat shape: {o.shape}')
Original params: 38880232
Pruned params: 35212800
Feat shape: torch.Size([2, 512, 16, 16])
< > 在 GitHub 上更新