timm 文档

ResNeSt

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

ResNeSt

ResNeStResNet 的一个变体,它堆叠了 Split-Attention 块。然后,基数组表示沿通道维度进行拼接V=Concat{V1,V2,,VK} V = \text{Concat} \{ V^{1},V^{2},\cdots,{V}^{K} \} 。与标准残差块一样,Split-Attention 块的最终输出Y Y 通过快捷连接生成Y=V+X Y=V+X ,如果输入和输出特征图具有相同的形状。对于带有步幅的块,将适当的转换T \mathcal{T} 应用于快捷连接以对齐输出形状Y=V+T(X) Y=V+\mathcal{T}(X) 。例如,T \mathcal{T} 可以是带步幅的卷积或组合的卷积-池化。

如何在图像上使用此模型?

加载预训练模型

>>> import timm
>>> model = timm.create_model('resnest101e', pretrained=True)
>>> model.eval()

加载并预处理图像

>>> import urllib
>>> from PIL import Image
>>> from timm.data import resolve_data_config
>>> from timm.data.transforms_factory import create_transform

>>> config = resolve_data_config({}, model=model)
>>> transform = create_transform(**config)

>>> url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
>>> urllib.request.urlretrieve(url, filename)
>>> img = Image.open(filename).convert('RGB')
>>> tensor = transform(img).unsqueeze(0) # transform and add batch dimension

获取模型预测结果

>>> import torch
>>> with torch.inference_mode():
...     out = model(tensor)
>>> probabilities = torch.nn.functional.softmax(out[0], dim=0)
>>> print(probabilities.shape)
>>> # prints: torch.Size([1000])

获取排名前 5 的预测类别名称

>>> # Get imagenet class mappings
>>> url, filename = ("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt", "imagenet_classes.txt")
>>> urllib.request.urlretrieve(url, filename)
>>> with open("imagenet_classes.txt", "r") as f:
...     categories = [s.strip() for s in f.readlines()]

>>> # Print top categories per image
>>> top5_prob, top5_catid = torch.topk(probabilities, 5)
>>> for i in range(top5_prob.size(0)):
...     print(categories[top5_catid[i]], top5_prob[i].item())
>>> # prints class names and probabilities like:
>>> # [('Samoyed', 0.6425196528434753), ('Pomeranian', 0.04062102362513542), ('keeshond', 0.03186424449086189), ('white wolf', 0.01739676296710968), ('Eskimo dog', 0.011717947199940681)]

将模型名称替换为您要使用的变体,例如 resnest101e。您可以在本页顶部的模型摘要中找到 ID。

要使用此模型提取图像特征,请遵循 timm 特征提取示例,只需更改你想使用的模型名称。

如何微调此模型?

你可以通过更改分类器(最后一层)来微调任何预训练模型。

>>> model = timm.create_model('resnest101e', pretrained=True, num_classes=NUM_FINETUNE_CLASSES)

要在自己的数据集上进行微调,你需要编写一个训练循环或修改 timm 的训练脚本以使用你的数据集。

如何训练此模型?

你可以按照 timm 食谱脚本来重新训练一个新模型。

引用

@misc{zhang2020resnest,
      title={ResNeSt: Split-Attention Networks},
      author={Hang Zhang and Chongruo Wu and Zhongyue Zhang and Yi Zhu and Haibin Lin and Zhi Zhang and Yue Sun and Tong He and Jonas Mueller and R. Manmatha and Mu Li and Alexander Smola},
      year={2020},
      eprint={2004.08955},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
< > 在 GitHub 上更新