timm 文档
模型
并获得增强的文档体验
开始使用
模型
timm.create_model
< source >( model_name: str pretrained: bool = False pretrained_cfg: typing.Union[str, typing.Dict[str, typing.Any], timm.models._pretrained.PretrainedCfg, NoneType] = None pretrained_cfg_overlay: typing.Optional[typing.Dict[str, typing.Any]] = None checkpoint_path: typing.Union[str, pathlib.Path, NoneType] = None cache_dir: typing.Union[str, pathlib.Path, NoneType] = None scriptable: typing.Optional[bool] = None exportable: typing.Optional[bool] = None no_jit: typing.Optional[bool] = None **kwargs )
参数
- model_name — 要实例化的模型名称。
- pretrained — 如果设置为 True,则加载 ImageNet-1k 预训练权重。
- pretrained_cfg — 传入模型的外部 pretrained_cfg。
- pretrained_cfg_overlay — 使用这些替换基本 pretrained_cfg 中的键值对。
- checkpoint_path — 初始化模型之后要加载的检查点路径。
- cache_dir — 覆盖 Hugging Face Hub 和 Torch 检查点的模型缓存目录。
- scriptable — 设置图层配置,使模型可用于 jit 脚本化(并非所有模型都适用)。
- exportable — 设置图层配置,使模型可追踪/可导出为 ONNX 格式(尚未完全实现/遵守)。
- no_jit — 设置图层配置,使模型不使用 jit 脚本化图层(目前仅限激活函数)。
创建一个模型。
查找模型的入口点函数并传递相关参数以创建一个新模型。
提示:**kwargs 将通过入口点函数传递到 timm.models.build_model_with_cfg()
,然后传递到模型类的 init()。设置为 None 的 kwargs 值在传递之前会被修剪。
关键字参数: drop_rate (float): 训练的分类器 dropout 率。 drop_path_rate (float): 训练的随机深度 dropout 率。 global_pool (str): 分类器全局池化类型。
示例
>>> from timm import create_model
>>> # Create a MobileNetV3-Large model with no pretrained weights.
>>> model = create_model('mobilenetv3_large_100')
>>> # Create a MobileNetV3-Large model with pretrained weights.
>>> model = create_model('mobilenetv3_large_100', pretrained=True)
>>> model.num_classes
1000
>>> # Create a MobileNetV3-Large model with pretrained weights and a new head with 10 classes.
>>> model = create_model('mobilenetv3_large_100', pretrained=True, num_classes=10)
>>> model.num_classes
10
>>> # Create a Dinov2 small model with pretrained weights and save weights in a custom directory.
>>> model = create_model('vit_small_patch14_dinov2.lvd142m', pretrained=True, cache_dir="/data/my-models")
>>> # Data will be stored at */data/my-models/models--timm--vit_small_patch14_dinov2.lvd142m/*
timm.list_models
< source >( filter: typing.Union[str, typing.List[str]] = '' module: typing.Union[str, typing.List[str]] = '' pretrained: bool = False exclude_filters: typing.Union[str, typing.List[str]] = '' name_matches_cfg: bool = False include_tags: typing.Optional[bool] = None )
参数
- filter - 与 fnmatch 一起使用的通配符过滤器字符串 —
- module - 将模型选择限制为特定的子模块(例如 ‘vision_transformer’)—
- pretrained - 如果为 True,则仅包含具有有效预训练权重的模型 —
- exclude_filters - 使用过滤器包含模型后要排除的通配符过滤器 —
- name_matches_cfg - 仅包含 model_name 与 default_cfg 名称匹配的模型(排除某些别名)—
- include_tags - 在模型名称中包含预训练标签 (model.tag)。 如果为 None,则默认值 — 当 pretrained=True 时设置为 True,否则为 False (默认值: None)
返回可用模型名称的列表,按字母顺序排序
示例: model_list(‘gluon_resnet’) — 返回所有以 ‘gluon_resnet’ 开头的模型 model_list(’resnext*, ‘resnet’) — 返回 ‘resnet’ 模块中所有包含 ‘resnext’ 的模型