timm 文档

模型

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

模型

timm.create_model

< >

( model_name: str pretrained: bool = False pretrained_cfg: typing.Union[str, typing.Dict[str, typing.Any], timm.models._pretrained.PretrainedCfg, NoneType] = None pretrained_cfg_overlay: typing.Optional[typing.Dict[str, typing.Any]] = None checkpoint_path: typing.Union[str, pathlib.Path, NoneType] = None cache_dir: typing.Union[str, pathlib.Path, NoneType] = None scriptable: typing.Optional[bool] = None exportable: typing.Optional[bool] = None no_jit: typing.Optional[bool] = None **kwargs )

参数

  • model_name — 要实例化的模型名称。
  • pretrained — 如果设置为 True,则加载 ImageNet-1k 预训练权重。
  • pretrained_cfg — 传入模型的外部 pretrained_cfg。
  • pretrained_cfg_overlay — 使用这些替换基本 pretrained_cfg 中的键值对。
  • checkpoint_path — 初始化模型之后要加载的检查点路径。
  • cache_dir — 覆盖 Hugging Face Hub 和 Torch 检查点的模型缓存目录。
  • scriptable — 设置图层配置,使模型可用于 jit 脚本化(并非所有模型都适用)。
  • exportable — 设置图层配置,使模型可追踪/可导出为 ONNX 格式(尚未完全实现/遵守)。
  • no_jit — 设置图层配置,使模型不使用 jit 脚本化图层(目前仅限激活函数)。

创建一个模型。

查找模型的入口点函数并传递相关参数以创建一个新模型。

提示:**kwargs 将通过入口点函数传递到 timm.models.build_model_with_cfg(),然后传递到模型类的 init()。设置为 None 的 kwargs 值在传递之前会被修剪。

关键字参数: drop_rate (float): 训练的分类器 dropout 率。 drop_path_rate (float): 训练的随机深度 dropout 率。 global_pool (str): 分类器全局池化类型。

示例

>>> from timm import create_model

>>> # Create a MobileNetV3-Large model with no pretrained weights.
>>> model = create_model('mobilenetv3_large_100')

>>> # Create a MobileNetV3-Large model with pretrained weights.
>>> model = create_model('mobilenetv3_large_100', pretrained=True)
>>> model.num_classes
1000

>>> # Create a MobileNetV3-Large model with pretrained weights and a new head with 10 classes.
>>> model = create_model('mobilenetv3_large_100', pretrained=True, num_classes=10)
>>> model.num_classes
10

>>> # Create a Dinov2 small model with pretrained weights and save weights in a custom directory.
>>> model = create_model('vit_small_patch14_dinov2.lvd142m', pretrained=True, cache_dir="/data/my-models")
>>> # Data will be stored at */data/my-models/models--timm--vit_small_patch14_dinov2.lvd142m/*

timm.list_models

< >

( filter: typing.Union[str, typing.List[str]] = '' module: typing.Union[str, typing.List[str]] = '' pretrained: bool = False exclude_filters: typing.Union[str, typing.List[str]] = '' name_matches_cfg: bool = False include_tags: typing.Optional[bool] = None )

参数

  • filter - 与 fnmatch 一起使用的通配符过滤器字符串 —
  • module - 将模型选择限制为特定的子模块(例如 ‘vision_transformer’)—
  • pretrained - 如果为 True,则仅包含具有有效预训练权重的模型 —
  • exclude_filters - 使用过滤器包含模型后要排除的通配符过滤器 —
  • name_matches_cfg - 仅包含 model_name 与 default_cfg 名称匹配的模型(排除某些别名)—
  • include_tags - 在模型名称中包含预训练标签 (model.tag)。 如果为 None,则默认值 — 当 pretrained=True 时设置为 True,否则为 False (默认值: None)

返回可用模型名称的列表,按字母顺序排序

示例: model_list(‘gluon_resnet’) — 返回所有以 ‘gluon_resnet’ 开头的模型 model_list(’resnext*, ‘resnet’) — 返回 ‘resnet’ 模块中所有包含 ‘resnext’ 的模型

< > 在 GitHub 上更新