timm 文档

特征提取

Hugging Face's logo
加入 Hugging Face 社区

并获得增强型文档体验

开始使用

特征提取

timm 中的所有模型都具有获取模型各种类型特征的一致机制,这些特征用于分类之外的任务。

倒数第二层特征(分类器前特征)

可以通过几种方法获取倒数第二层模型的特征,而无需进行模型修改(尽管可以随意进行修改)。首先必须决定是否需要池化特征或非池化特征。

非池化

有三种方法可以获取非池化特征。最终的非池化特征有时被称为最后一个隐藏状态。在 timm 中,这包括最终归一化层(例如在 ViT 风格的模型中),但不包括池化/类别标记选择和最终池化后层。

无需修改网络,就可以在任何模型上调用 model.forward_features(input) 而不是通常的 model(input)。这将绕过网络的头部分类器和全局池化。

如果要明确修改网络以返回非池化特征,则可以创建没有分类器和池化的模型,或者稍后将其移除。这两种方法都会从网络中移除与分类器相关的参数。

forward_features()

>>> import torch
>>> import timm
>>> m = timm.create_model('xception41', pretrained=True)
>>> o = m(torch.randn(2, 3, 299, 299))
>>> print(f'Original shape: {o.shape}')
>>> o = m.forward_features(torch.randn(2, 3, 299, 299))
>>> print(f'Unpooled shape: {o.shape}')

输出

Original shape: torch.Size([2, 1000])
Unpooled shape: torch.Size([2, 2048, 10, 10])

创建无分类器和池化的模型

>>> import torch
>>> import timm
>>> m = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool='')
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Unpooled shape: {o.shape}')

输出

Unpooled shape: torch.Size([2, 2048, 7, 7])

稍后移除它

>>> import torch
>>> import timm
>>> m = timm.create_model('densenet121', pretrained=True)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Original shape: {o.shape}')
>>> m.reset_classifier(0, '')
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Unpooled shape: {o.shape}')

输出

Original shape: torch.Size([2, 1000])
Unpooled shape: torch.Size([2, 1024, 7, 7])

将非池化输出链接到分类器

可以使用 forward_head() 函数将最后一个隐藏状态反馈到模型的头部。

>>> model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
>>> output = model.forward_features(torch.randn(2,3,256,256))
>>> print('Unpooled output shape:', output.shape)
>>> classified = model.forward_head(output)
>>> print('Classification output shape:', classified.shape)

输出

Unpooled output shape: torch.Size([2, 257, 512])
Classification output shape: torch.Size([2, 1000])

池化

要修改网络以返回池化特征,可以使用 forward_features() 并自行池化/展平结果,或者像上面一样修改网络,但保持池化不变。

不带分类器的创建

>>> import torch
>>> import timm
>>> m = timm.create_model('resnet50', pretrained=True, num_classes=0)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Pooled shape: {o.shape}')

输出

Pooled shape: torch.Size([2, 2048])

稍后移除它

>>> import torch
>>> import timm
>>> m = timm.create_model('ese_vovnet19b_dw', pretrained=True)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Original shape: {o.shape}')
>>> m.reset_classifier(0)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Pooled shape: {o.shape}')

输出

Original shape: torch.Size([2, 1000])
Pooled shape: torch.Size([2, 1024])

多尺度特征图(特征金字塔)

目标检测、分割、关键点以及各种密集像素任务都需要访问来自主干网络的多尺度特征图。这通常通过修改原始分类网络来实现。由于每个网络在结构上差异很大,因此在任何给定的目标检测或分割库中只看到支持少数几个主干网络的情况并不少见。

timm 允许使用一致的接口创建任何包含的模型作为特征主干,这些主干输出选定级别的特征图。

可以通过向任何 create_model 调用添加参数 features_only=True 来创建特征主干。默认情况下,大多数具有特征层次结构的模型将输出最多 5 个特征,最多缩减 32 倍。但是这因模型而异,一些模型具有较少的层次结构级别,而一些模型(如 ViT)具有大量非层次结构特征图,它们默认输出最后 3 个。可以将 out_indices 参数传递给 create_model 以指定所需的特征。

创建特征图提取模型

>>> import torch
>>> import timm
>>> m = timm.create_model('resnest26d', features_only=True, pretrained=True)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> for x in o:
...     print(x.shape)

输出

torch.Size([2, 64, 112, 112])
torch.Size([2, 256, 56, 56])
torch.Size([2, 512, 28, 28])
torch.Size([2, 1024, 14, 14])
torch.Size([2, 2048, 7, 7])

查询特征信息

创建特征主干后,可以查询它以向下游头部提供通道或分辨率缩减信息,而无需静态配置或硬编码常量。.feature_info 属性是一个类,封装了有关特征提取点的信息。

>>> import torch
>>> import timm
>>> m = timm.create_model('regnety_032', features_only=True, pretrained=True)
>>> print(f'Feature channels: {m.feature_info.channels()}')
>>> o = m(torch.randn(2, 3, 224, 224))
>>> for x in o:
...     print(x.shape)

输出

Feature channels: [32, 72, 216, 576, 1512]
torch.Size([2, 32, 112, 112])
torch.Size([2, 72, 56, 56])
torch.Size([2, 216, 28, 28])
torch.Size([2, 576, 14, 14])
torch.Size([2, 1512, 7, 7])

选择特定的特征级别或限制步长

还有两个额外的创建参数会影响输出特征。

  • out_indices 选择要输出的索引
  • output_stride 限制网络的特征输出步长(顺便说一下,在分类模式下也适用)

输出索引选择

所有模型都支持 out_indices 参数,但并非所有模型都具有相同的索引到特征步长映射。查看代码或检查 feature_info 以进行比较。输出索引通常对应于 C(i+1)th 特征级别(2^(i+1) 倍缩减)。对于大多数卷积神经网络模型,索引 0 是步长为 2 的特征,索引 4 是步长为 32 的特征。对于许多 ViT 或 ViT-Conv 混合模型,可能存在许多或所有相同形状的特征图,或者层次结构和非层次结构特征图的组合。最好查看 feature_info 属性以查看特征的数量、它们对应的通道数和缩减级别。

out_indices 支持负索引,这使得获取最后一个、倒数第二个等特征图变得很容易。out_indices=(-2,) 将为任何模型返回倒数第二个特征图。

输出步长(特征图扩张)

output_stride 是通过将层转换为使用扩张卷积来实现的。这样做并不总是简单明了,一些网络只支持 output_stride=32

>>> import torch
>>> import timm
>>> m = timm.create_model('ecaresnet101d', features_only=True, output_stride=8, out_indices=(2, 4), pretrained=True)
>>> print(f'Feature channels: {m.feature_info.channels()}')
>>> print(f'Feature reduction: {m.feature_info.reduction()}')
>>> o = m(torch.randn(2, 3, 320, 320))
>>> for x in o:
...     print(x.shape)

输出

Feature channels: [512, 2048]
Feature reduction: [8, 8]
torch.Size([2, 512, 40, 40])
torch.Size([2, 2048, 40, 40])

灵活的中间特征图提取

除了在模型工厂中使用 features_only 之外,许多模型还支持 forward_intermediates() 方法,该方法提供了一种灵活的机制来提取中间特征图和最后一个隐藏状态(可以将其链接到头部)。此外,此方法支持一些特定于模型的功能,例如为某些模型返回类或蒸馏前缀标记。

forward_intermediates 函数一起的是 prune_intermediate_layers 函数,它允许从模型中修剪层,包括头部、最终归一化和/或不需要的尾随块/阶段。

indices 参数用于 forward_intermediates()prune_intermediate_layers() 来选择要返回的特征或要移除的层。与 features_only API 的 out_indices 一样,indices 是特定于模型的,并选择返回哪些中间值。

在像 ViT 这样的非层次块模型中,索引对应于块,在具有层次阶段的模型中,它们通常对应于茎 + 每个层次阶段的输出。正索引(从开头开始)和负索引(相对于结尾)都适用,None 用于返回所有中间值。

prune_intermediate_layers() 调用返回一个索引变量,因为当模型被修剪时,负索引必须转换为绝对(正)索引。

model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
output, intermediates = model.forward_intermediates(torch.randn(2,3,256,256))
for i, o in enumerate(intermediates):
    print(f'Feat index: {i}, shape: {o.shape}')
Feat index: 0, shape: torch.Size([2, 512, 16, 16])
Feat index: 1, shape: torch.Size([2, 512, 16, 16])
Feat index: 2, shape: torch.Size([2, 512, 16, 16])
Feat index: 3, shape: torch.Size([2, 512, 16, 16])
Feat index: 4, shape: torch.Size([2, 512, 16, 16])
Feat index: 5, shape: torch.Size([2, 512, 16, 16])
Feat index: 6, shape: torch.Size([2, 512, 16, 16])
Feat index: 7, shape: torch.Size([2, 512, 16, 16])
Feat index: 8, shape: torch.Size([2, 512, 16, 16])
Feat index: 9, shape: torch.Size([2, 512, 16, 16])
Feat index: 10, shape: torch.Size([2, 512, 16, 16])
Feat index: 11, shape: torch.Size([2, 512, 16, 16])
model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
print('Original params:', sum([p.numel() for p in model.parameters()]))

indices = model.prune_intermediate_layers(indices=(-2,), prune_head=True, prune_norm=True)  # prune head, norm, last block
print('Pruned params:', sum([p.numel() for p in model.parameters()]))

intermediates = model.forward_intermediates(torch.randn(2,3,256,256), indices=indices, intermediates_only=True)  # return penultimate intermediate
for o in intermediates:    
    print(f'Feat shape: {o.shape}')
Original params: 38880232
Pruned params: 35212800
Feat shape: torch.Size([2, 512, 16, 16])
< > 在 GitHub 上更新