timm 文档

特征提取

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

特征提取

timm 中所有模型都有一致的机制,用于从模型中获取各种类型的特征,以完成分类之外的任务。

倒数第二层特征(预分类器特征)

有多种方法可以获取倒数第二层模型的特征,而无需对模型进行修改(当然,您也可以随意进行修改)。首先必须决定是需要池化还是非池化的特征。

非池化

有三种方法可以获得非池化特征。最终的、非池化的特征有时被称为最后隐藏状态。在 timm 中,这包括到最终归一化层为止(例如在 ViT 风格的模型中),但不包括池化/类别词元选择和最终的后池化层。

无需修改网络,可以在任何模型上调用 `model.forward_features(input)`,而不是通常的 `model(input)`。这将绕过网络的头部(分类器)和全局池化。

如果想显式地修改网络以返回非池化特征,可以创建一个不带分类器和池化的模型,或者稍后移除它们。这两种方法都会从网络中移除与分类器相关的参数。

forward_features()

>>> import torch
>>> import timm
>>> m = timm.create_model('xception41', pretrained=True)
>>> o = m(torch.randn(2, 3, 299, 299))
>>> print(f'Original shape: {o.shape}')
>>> o = m.forward_features(torch.randn(2, 3, 299, 299))
>>> print(f'Unpooled shape: {o.shape}')

输出

Original shape: torch.Size([2, 1000])
Unpooled shape: torch.Size([2, 2048, 10, 10])

创建不带分类器和池化的模型

>>> import torch
>>> import timm
>>> m = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool='')
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Unpooled shape: {o.shape}')

输出

Unpooled shape: torch.Size([2, 2048, 7, 7])

稍后移除

>>> import torch
>>> import timm
>>> m = timm.create_model('densenet121', pretrained=True)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Original shape: {o.shape}')
>>> m.reset_classifier(0, '')
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Unpooled shape: {o.shape}')

输出

Original shape: torch.Size([2, 1000])
Unpooled shape: torch.Size([2, 1024, 7, 7])

将非池化输出链接到分类器

使用 `forward_head()` 函数,可以将最后的隐藏状态反馈给模型的头部。

>>> model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
>>> output = model.forward_features(torch.randn(2,3,256,256))
>>> print('Unpooled output shape:', output.shape)
>>> classified = model.forward_head(output)
>>> print('Classification output shape:', classified.shape)

输出

Unpooled output shape: torch.Size([2, 257, 512])
Classification output shape: torch.Size([2, 1000])

池化

要修改网络以返回池化特征,可以使用 `forward_features()` 并自行对结果进行池化/展平,或者像上面那样修改网络但保持池化层不变。

创建不带分类器的模型

>>> import torch
>>> import timm
>>> m = timm.create_model('resnet50', pretrained=True, num_classes=0)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Pooled shape: {o.shape}')

输出

Pooled shape: torch.Size([2, 2048])

稍后移除

>>> import torch
>>> import timm
>>> m = timm.create_model('ese_vovnet19b_dw', pretrained=True)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Original shape: {o.shape}')
>>> m.reset_classifier(0)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Pooled shape: {o.shape}')

输出

Original shape: torch.Size([2, 1000])
Pooled shape: torch.Size([2, 1024])

多尺度特征图(特征金字塔)

目标检测、分割、关键点检测以及各种密集像素任务需要从骨干网络中获取多尺度的特征图。这通常通过修改原始的分类网络来实现。由于每个网络的结构差异很大,任何给定的目标检测或分割库通常只支持少数几种骨干网络。

timm 提供了一个一致的接口,可以将任何包含的模型创建为特征骨干网络,并输出选定级别的特征图。

通过在任何 `create_model` 调用中添加 `features_only=True` 参数,可以创建一个特征骨干网络。默认情况下,大多数具有特征层次结构的模型将输出最多 5 个特征,最大缩减率为 32。然而,这因模型而异,有些模型层次较少,而有些(如 ViT)则有大量非层次化的特征图,它们默认输出最后 3 个。可以向 `create_model` 传递 `out_indices` 参数来指定你想要的特征。

创建一个特征图提取模型

>>> import torch
>>> import timm
>>> m = timm.create_model('resnest26d', features_only=True, pretrained=True)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> for x in o:
...     print(x.shape)

输出

torch.Size([2, 64, 112, 112])
torch.Size([2, 256, 56, 56])
torch.Size([2, 512, 28, 28])
torch.Size([2, 1024, 14, 14])
torch.Size([2, 2048, 7, 7])

查询特征信息

特征骨干网络创建后,可以查询其通道或分辨率缩减信息,以提供给下游的头部,而无需静态配置或硬编码常量。`.feature_info` 属性是一个封装了特征提取点信息的类。

>>> import torch
>>> import timm
>>> m = timm.create_model('regnety_032', features_only=True, pretrained=True)
>>> print(f'Feature channels: {m.feature_info.channels()}')
>>> o = m(torch.randn(2, 3, 224, 224))
>>> for x in o:
...     print(x.shape)

输出

Feature channels: [32, 72, 216, 576, 1512]
torch.Size([2, 32, 112, 112])
torch.Size([2, 72, 56, 56])
torch.Size([2, 216, 28, 28])
torch.Size([2, 576, 14, 14])
torch.Size([2, 1512, 7, 7])

选择特定特征级别或限制步幅

还有两个额外的创建参数会影响输出的特征。

  • out_indices 选择要输出的索引
  • output_stride 限制网络的特征输出步幅(顺便说一下,在分类模式下也有效)

输出索引选择

`out_indices` 参数受所有模型支持,但并非所有模型都具有相同的索引到特征步幅的映射关系。请查看代码或检查 `feature_info` 进行比较。输出索引通常对应于 `C(i+1)` 特征级别(即 `2^(i+1)` 倍的缩减)。对于大多数卷积神经网络模型,索引 0 是步幅为 2 的特征,索引 4 是步幅为 32 的特征。对于许多 ViT 或 ViT-Conv 混合模型,可能有很多或所有特征图都具有相同的形状,或者是层次化和非层次化特征图的组合。最好查看 `feature_info` 属性,以了解特征的数量、它们对应的通道数和缩减级别。

out_indices 支持负索引,这使得获取最后一个、倒数第二个等特征图变得容易。out_indices=(-2,) 将返回任何模型的倒数第二个特征图。

输出步幅(特征图扩张)

output_stride 是通过将层转换为使用扩张卷积来实现的。这样做并不总是直接的,一些网络仅支持 output_stride=32

>>> import torch
>>> import timm
>>> m = timm.create_model('ecaresnet101d', features_only=True, output_stride=8, out_indices=(2, 4), pretrained=True)
>>> print(f'Feature channels: {m.feature_info.channels()}')
>>> print(f'Feature reduction: {m.feature_info.reduction()}')
>>> o = m(torch.randn(2, 3, 320, 320))
>>> for x in o:
...     print(x.shape)

输出

Feature channels: [512, 2048]
Feature reduction: [8, 8]
torch.Size([2, 512, 40, 40])
torch.Size([2, 2048, 40, 40])

灵活的中间特征图提取

除了使用模型工厂的 `features_only` 参数外,许多模型还支持一个 `forward_intermediates()` 方法,它提供了一种灵活的机制来提取中间特征图和最后的隐藏状态(可以链接到头部)。此外,该方法还支持一些模型特定的特性,例如为某些模型返回类别或蒸馏前缀词元。

与 `forward_intermediates` 函数相伴的是一个 `prune_intermediate_layers` 函数,它允许你从模型中修剪层,包括头部、最终归一化层和/或不需要的尾部块/阶段。

一个 `indices` 参数同时用于 `forward_intermediates()` 和 `prune_intermediate_layers()`,以选择要返回的特征或要移除的层。与 `features_only` API 的 `out_indices` 一样,`indices` 是模型特定的,并选择返回哪些中间结果。

在非层次化的基于块的模型(如 ViT)中,索引对应于块;在具有层次化阶段的模型中,它们通常对应于主干(stem)和每个层次化阶段的输出。支持正向(从头开始)和负向(相对于结尾)索引,而 `None` 用于返回所有中间结果。

在修剪模型时,`prune_intermediate_layers()` 调用会返回一个索引变量,因为负索引必须转换为绝对(正)索引。

model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
output, intermediates = model.forward_intermediates(torch.randn(2,3,256,256))
for i, o in enumerate(intermediates):
    print(f'Feat index: {i}, shape: {o.shape}')
Feat index: 0, shape: torch.Size([2, 512, 16, 16])
Feat index: 1, shape: torch.Size([2, 512, 16, 16])
Feat index: 2, shape: torch.Size([2, 512, 16, 16])
Feat index: 3, shape: torch.Size([2, 512, 16, 16])
Feat index: 4, shape: torch.Size([2, 512, 16, 16])
Feat index: 5, shape: torch.Size([2, 512, 16, 16])
Feat index: 6, shape: torch.Size([2, 512, 16, 16])
Feat index: 7, shape: torch.Size([2, 512, 16, 16])
Feat index: 8, shape: torch.Size([2, 512, 16, 16])
Feat index: 9, shape: torch.Size([2, 512, 16, 16])
Feat index: 10, shape: torch.Size([2, 512, 16, 16])
Feat index: 11, shape: torch.Size([2, 512, 16, 16])
model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
print('Original params:', sum([p.numel() for p in model.parameters()]))

indices = model.prune_intermediate_layers(indices=(-2,), prune_head=True, prune_norm=True)  # prune head, norm, last block
print('Pruned params:', sum([p.numel() for p in model.parameters()]))

intermediates = model.forward_intermediates(torch.randn(2,3,256,256), indices=indices, intermediates_only=True)  # return penultimate intermediate
for o in intermediates:    
    print(f'Feat shape: {o.shape}')
Original params: 38880232
Pruned params: 35212800
Feat shape: torch.Size([2, 512, 16, 16])
< > 在 GitHub 上更新