timm 文档

对抗性Inception v3

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

对抗性Inception v3

Inception v3 是 Inception 系列中的卷积神经网络架构,它进行了多项改进,包括使用标签平滑、分解 7 x 7 卷积,以及使用辅助分类器以将标签信息传播到网络更低层(同时在侧头层中使用批量归一化)。其关键构建块是Inception 模块

该特定模型用于对抗性样本研究(对抗性训练)。

该模型的权重从 Tensorflow/Models 移植而来。

如何将此模型应用于图像?

加载预训练模型

>>> import timm
>>> model = timm.create_model('adv_inception_v3', pretrained=True)
>>> model.eval()

加载并预处理图像

>>> import urllib
>>> from PIL import Image
>>> from timm.data import resolve_data_config
>>> from timm.data.transforms_factory import create_transform

>>> config = resolve_data_config({}, model=model)
>>> transform = create_transform(**config)

>>> url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
>>> urllib.request.urlretrieve(url, filename)
>>> img = Image.open(filename).convert('RGB')
>>> tensor = transform(img).unsqueeze(0) # transform and add batch dimension

获取模型预测结果

>>> import torch
>>> with torch.inference_mode():
...     out = model(tensor)
>>> probabilities = torch.nn.functional.softmax(out[0], dim=0)
>>> print(probabilities.shape)
>>> # prints: torch.Size([1000])

获取排名前 5 的预测类别名称

>>> # Get imagenet class mappings
>>> url, filename = ("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt", "imagenet_classes.txt")
>>> urllib.request.urlretrieve(url, filename) 
>>> with open("imagenet_classes.txt", "r") as f:
...     categories = [s.strip() for s in f.readlines()]

>>> # Print top categories per image
>>> top5_prob, top5_catid = torch.topk(probabilities, 5)
>>> for i in range(top5_prob.size(0)):
...     print(categories[top5_catid[i]], top5_prob[i].item())
>>> # prints class names and probabilities like:
>>> # [('Samoyed', 0.6425196528434753), ('Pomeranian', 0.04062102362513542), ('keeshond', 0.03186424449086189), ('white wolf', 0.01739676296710968), ('Eskimo dog', 0.011717947199940681)]

将模型名称替换为您要使用的变体,例如 adv_inception_v3。您可以在本页顶部的模型摘要中找到 ID。

要使用此模型提取图像特征,请遵循 timm 特征提取示例,只需更改你想使用的模型名称。

如何微调此模型?

你可以通过更改分类器(最后一层)来微调任何预训练模型。

>>> model = timm.create_model('adv_inception_v3', pretrained=True, num_classes=NUM_FINETUNE_CLASSES)

要在自己的数据集上进行微调,你需要编写一个训练循环或修改 timm 的训练脚本以使用你的数据集。

如何训练此模型?

你可以按照 timm 食谱脚本来重新训练一个新模型。

引用

@article{DBLP:journals/corr/abs-1804-00097,
  author    = {Alexey Kurakin and
               Ian J. Goodfellow and
               Samy Bengio and
               Yinpeng Dong and
               Fangzhou Liao and
               Ming Liang and
               Tianyu Pang and
               Jun Zhu and
               Xiaolin Hu and
               Cihang Xie and
               Jianyu Wang and
               Zhishuai Zhang and
               Zhou Ren and
               Alan L. Yuille and
               Sangxia Huang and
               Yao Zhao and
               Yuzhe Zhao and
               Zhonglin Han and
               Junjiajia Long and
               Yerkebulan Berdibekov and
               Takuya Akiba and
               Seiya Tokui and
               Motoki Abe},
  title     = {Adversarial Attacks and Defences Competition},
  journal   = {CoRR},
  volume    = {abs/1804.00097},
  year      = {2018},
  url       = {http://arxiv.org/abs/1804.00097},
  archivePrefix = {arXiv},
  eprint    = {1804.00097},
  timestamp = {Thu, 31 Oct 2019 16:31:22 +0100},
  biburl    = {https://dblp.org/rec/journals/corr/abs-1804-00097.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
< > 在 GitHub 上更新