快速入门
本快速入门适用于准备深入代码并查看将 timm
集成到其模型训练工作流程的示例的开发人员。
首先,您需要安装 timm
。有关安装的更多信息,请参阅 安装。
pip install timm
加载预训练模型
可以使用 create_model() 加载预训练模型。
这里,我们加载预训练的 mobilenetv3_large_100
模型。
>>> import timm
>>> m = timm.create_model('mobilenetv3_large_100', pretrained=True)
>>> m.eval()
列出具有预训练权重的模型
要列出与 timm
打包的模型,可以使用 list_models()。如果您指定 pretrained=True
,此函数将仅返回具有关联预训练权重的模型名称。
>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models(pretrained=True)
>>> pprint(model_names)
[
'adv_inception_v3',
'cspdarknet53',
'cspresnext50',
'densenet121',
'densenet161',
'densenet169',
'densenet201',
'densenetblur121d',
'dla34',
'dla46_c',
]
您还可以列出名称中具有特定模式的模型。
>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models('*resne*t*')
>>> pprint(model_names)
[
'cspresnet50',
'cspresnet50d',
'cspresnet50w',
'cspresnext50',
...
]
微调预训练模型
您可以通过更改分类器(最后一层)来微调任何预训练模型。
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=NUM_FINETUNE_CLASSES)
要在您自己的数据集上进行微调,您必须编写 PyTorch 训练循环或调整 timm
的 训练脚本 以使用您的数据集。
使用预训练模型进行特征提取
在不修改网络的情况下,可以在任何模型上调用 model.forward_features(input) 而不是通常的 model(input)。这将绕过网络的头部分类器和全局池化。
有关将 timm
用于特征提取的更深入指南,请参阅 特征提取。
>>> import timm
>>> import torch
>>> x = torch.randn(1, 3, 224, 224)
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True)
>>> features = model.forward_features(x)
>>> print(features.shape)
torch.Size([1, 960, 7, 7])
图像增强
要将图像转换为模型的有效输入,可以使用 timm.data.create_transform(),提供模型期望的所需 input_size
。
这将返回一个使用合理默认值的通用转换。
>>> timm.data.create_transform((3, 224, 224))
Compose(
Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
CenterCrop(size=(224, 224))
ToTensor()
Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)
预训练模型具有在训练期间应用于馈送到其中的图像的特定转换。如果您在图像上使用错误的转换,模型将无法理解它正在看到的内容!
要弄清楚特定预训练模型使用了哪些转换,我们可以首先查看其 pretrained_cfg
>>> model.pretrained_cfg
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth',
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': (7, 7),
'crop_pct': 0.875,
'interpolation': 'bicubic',
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
'first_conv': 'conv_stem',
'classifier': 'classifier',
'architecture': 'mobilenetv3_large_100'}
然后,我们可以使用 timm.data.resolve_data_config() 仅解析与数据相关的配置。
>>> timm.data.resolve_data_config(model.pretrained_cfg)
{'input_size': (3, 224, 224),
'interpolation': 'bicubic',
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
'crop_pct': 0.875}
我们可以将此数据配置传递给 timm.data.create_transform() 以初始化模型的关联转换。
>>> data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
>>> transform = timm.data.create_transform(**data_cfg)
>>> transform
Compose(
Resize(size=256, interpolation=bicubic, max_size=None, antialias=None)
CenterCrop(size=(224, 224))
ToTensor()
Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)
使用预训练模型进行推理
这里,我们将把上面的部分放在一起,并使用一个预训练模型进行推理。
首先,我们需要一张图像来进行推理。这里,我们从网络上加载了一张叶子的图片
>>> import requests
>>> from PIL import Image
>>> from io import BytesIO
>>> url = 'https://datasets-server.huggingface.co/assets/imagenet-1k/--/default/test/12/image/image.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image
这是我们加载的图像
现在,我们将再次创建我们的模型和转换。这次,我们要确保将模型设置为评估模式。
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True).eval()
>>> transform = timm.data.create_transform(
**timm.data.resolve_data_config(model.pretrained_cfg)
)
我们可以通过将图像传递给转换来为模型准备它。
>>> image_tensor = transform(image)
>>> image_tensor.shape
torch.Size([3, 224, 224])
现在,我们可以将图像传递给模型以获得预测。我们在这里使用 `unsqueeze(0)`,因为模型期望一个批次维度。
>>> output = model(image_tensor.unsqueeze(0))
>>> output.shape
torch.Size([1, 1000])
为了获得预测概率,我们将 softmax 应用于输出。这将使我们得到一个形状为 `(num_classes,)` 的张量。
>>> probabilities = torch.nn.functional.softmax(output[0], dim=0)
>>> probabilities.shape
torch.Size([1000])
现在,我们将使用 `torch.topk` 找到前 5 个预测类索引和值。
>>> values, indices = torch.topk(probabilities, 5)
>>> indices
tensor([162, 166, 161, 164, 167])
如果我们检查前一个索引的 ImageNet 标签,我们可以看到模型预测了什么......
>>> IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
>>> IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
>>> [{'label': IMAGENET_1k_LABELS[idx], 'value': val.item()} for val, idx in zip(values, indices)]
[{'label': 'beagle', 'value': 0.8486220836639404},
{'label': 'Walker_hound, Walker_foxhound', 'value': 0.03753996267914772},
{'label': 'basset, basset_hound', 'value': 0.024628572165966034},
{'label': 'bluetick', 'value': 0.010317106731235981},
{'label': 'English_foxhound', 'value': 0.006958036217838526}]