timm 文档

快速入门

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

快速入门

本快速入门旨在帮助准备好深入代码并查看如何将 timm 集成到其模型训练工作流程中的开发人员。

首先,您需要安装 timm。有关安装的更多信息,请参阅安装

pip install timm

加载预训练模型

预训练模型可以使用 create_model() 加载。

在这里,我们加载预训练的 mobilenetv3_large_100 模型。

>>> import timm

>>> m = timm.create_model('mobilenetv3_large_100', pretrained=True)
>>> m.eval()
注意:返回的 PyTorch 模型默认设置为训练模式,因此如果您计划将其用于推理,则必须对其调用 .eval()。

列出带有预训练权重的模型

要列出与 timm 打包的模型,您可以使用 list_models()。如果您指定 pretrained=True,此函数将仅返回具有关联的预训练权重的模型名称。

>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models(pretrained=True)
>>> pprint(model_names)
[
    'adv_inception_v3',
    'cspdarknet53',
    'cspresnext50',
    'densenet121',
    'densenet161',
    'densenet169',
    'densenet201',
    'densenetblur121d',
    'dla34',
    'dla46_c',
]

您还可以列出名称中具有特定模式的模型。

>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models('*resne*t*')
>>> pprint(model_names)
[
    'cspresnet50',
    'cspresnet50d',
    'cspresnet50w',
    'cspresnext50',
    ...
]

微调预训练模型

您可以通过更改分类器(最后一层)来微调任何预训练模型。

>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=NUM_FINETUNE_CLASSES)

要在您自己的数据集上进行微调,您必须编写 PyTorch 训练循环或调整 timm训练脚本 以使用您的数据集。

使用预训练模型进行特征提取

在不修改网络的情况下,可以对任何模型调用 model.forward_features(input) 而不是常用的 model(input)。这将绕过网络的头部分类器和全局池化。

有关使用 timm 进行特征提取的更深入指南,请参阅特征提取

>>> import timm
>>> import torch
>>> x = torch.randn(1, 3, 224, 224)
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True)
>>> features = model.forward_features(x)
>>> print(features.shape)
torch.Size([1, 960, 7, 7])

图像增强

要将图像转换为模型的有效输入,您可以使用 timm.data.create_transform(),并提供模型期望的所需 input_size

这将返回一个使用合理默认值的通用变换。

>>> timm.data.create_transform((3, 224, 224))
Compose(
    Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)

预训练模型具有特定的变换,这些变换在训练时应用于馈入模型的图像。如果您在图像上使用了错误的变换,则模型将不理解它所看到的内容!

要弄清楚给定预训练模型使用了哪些变换,我们可以首先查看其 pretrained_cfg

>>> model.pretrained_cfg
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth',
 'num_classes': 1000,
 'input_size': (3, 224, 224),
 'pool_size': (7, 7),
 'crop_pct': 0.875,
 'interpolation': 'bicubic',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'first_conv': 'conv_stem',
 'classifier': 'classifier',
 'architecture': 'mobilenetv3_large_100'}

然后,我们可以通过使用 timm.data.resolve_data_config() 仅解析与数据相关的配置。

>>> timm.data.resolve_data_config(model.pretrained_cfg)
{'input_size': (3, 224, 224),
 'interpolation': 'bicubic',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'crop_pct': 0.875}

我们可以将此数据配置传递给 timm.data.create_transform() 以初始化模型的关联变换。

>>> data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
>>> transform = timm.data.create_transform(**data_cfg)
>>> transform
Compose(
    Resize(size=256, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)
注意:在这里,预训练模型的配置恰好与我们之前制作的通用配置相同。情况并非总是如此。因此,使用数据配置创建变换(就像我们在此处所做的那样)比使用通用变换更安全。

使用预训练模型进行推理

在这里,我们将把以上各节放在一起,并使用预训练模型进行推理。

首先,我们需要一张图像来进行推理。在这里,我们从网上加载一张叶子的图片

>>> import requests
>>> from PIL import Image
>>> from io import BytesIO
>>> url = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image

这是我们加载的图像

An Image from a link

现在,我们将再次创建我们的模型和变换。这一次,我们确保将模型设置为评估模式。

>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True).eval()
>>> transform = timm.data.create_transform(
    **timm.data.resolve_data_config(model.pretrained_cfg)
)

我们可以通过将此图像传递给变换来为模型准备此图像。

>>> image_tensor = transform(image)
>>> image_tensor.shape
torch.Size([3, 224, 224])

现在我们可以将该图像传递给模型以获得预测。在这种情况下,我们使用 unsqueeze(0),因为模型期望批次维度。

>>> output = model(image_tensor.unsqueeze(0))
>>> output.shape
torch.Size([1, 1000])

为了获得预测概率,我们将 softmax 应用于输出。这为我们留下了一个形状为 (num_classes,) 的张量。

>>> probabilities = torch.nn.functional.softmax(output[0], dim=0)
>>> probabilities.shape
torch.Size([1000])

现在,我们将使用 torch.topk 找到前 5 个预测的类别索引和值。

>>> values, indices = torch.topk(probabilities, 5)
>>> indices
tensor([281, 282, 285, 673, 670])

如果我们检查顶部索引的 imagenet 标签,我们可以看到模型预测的内容……

>>> IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
>>> IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
>>> [{'label': IMAGENET_1k_LABELS[idx], 'value': val.item()} for val, idx in zip(values, indices)]
[{'label': 'tabby, tabby_cat', 'value': 0.5101025700569153},
 {'label': 'tiger_cat', 'value': 0.22490699589252472},
 {'label': 'Egyptian_cat', 'value': 0.1835290789604187},
 {'label': 'mouse, computer_mouse', 'value': 0.006752475164830685},
 {'label': 'motor_scooter, scooter', 'value': 0.004942195490002632}]
< > 在 GitHub 上更新