timm 文档
模型
并获得增强的文档体验
开始使用
模型
timm.create_model
< 源码 >( model_name: str pretrained: bool = False pretrained_cfg: typing.Union[str, typing.Dict[str, typing.Any], timm.models._pretrained.PretrainedCfg, NoneType] = None pretrained_cfg_overlay: typing.Optional[typing.Dict[str, typing.Any]] = None checkpoint_path: typing.Union[str, pathlib.Path, NoneType] = None cache_dir: typing.Union[str, pathlib.Path, NoneType] = None scriptable: typing.Optional[bool] = None exportable: typing.Optional[bool] = None no_jit: typing.Optional[bool] = None **kwargs: typing.Any )
参数
- model_name — 要实例化的模型名称。
- pretrained — 如果设置为 True,则加载预训练的 ImageNet-1k 权重。
- pretrained_cfg — 传入一个外部的 pretrained_cfg 用于模型。
- pretrained_cfg_overlay — 用这些键值对替换基础 pretrained_cfg 中的值。
- checkpoint_path — 模型初始化*后*要加载的检查点路径。
- cache_dir — 覆盖 Hugging Face Hub 和 Torch 检查点的模型缓存目录。
- scriptable — 设置层配置,使模型可进行 jit 脚本化(目前并非所有模型都支持)。
- exportable — 设置层配置,使模型可追踪/可导出为 ONNX 格式(尚未完全实现/遵循)。
- no_jit — 设置层配置,使模型不使用 jit 脚本化层(目前仅限激活层)。
创建一个模型。
查找模型的入口函数,并传递相关参数以创建一个新模型。
提示:**kwargs 将通过入口函数传递给 `timm.models.build_model_with_cfg()`,然后传递给模型类的 `__init__()`。在传递之前,值为 None 的 kwargs 会被删减。
关键字参数:drop_rate (float): 训练时分类器的 dropout 率。drop_path_rate (float): 训练时随机深度的 drop rate。global_pool (str): 分类器的全局池化类型。
示例
>>> from timm import create_model
>>> # Create a MobileNetV3-Large model with no pretrained weights.
>>> model = create_model('mobilenetv3_large_100')
>>> # Create a MobileNetV3-Large model with pretrained weights.
>>> model = create_model('mobilenetv3_large_100', pretrained=True)
>>> model.num_classes
1000
>>> # Create a MobileNetV3-Large model with pretrained weights and a new head with 10 classes.
>>> model = create_model('mobilenetv3_large_100', pretrained=True, num_classes=10)
>>> model.num_classes
10
>>> # Create a Dinov2 small model with pretrained weights and save weights in a custom directory.
>>> model = create_model('vit_small_patch14_dinov2.lvd142m', pretrained=True, cache_dir="/data/my-models")
>>> # Data will be stored at */data/my-models/models--timm--vit_small_patch14_dinov2.lvd142m/*
timm.list_models
< 源码 >( filter: typing.Union[str, typing.List[str]] = '' module: typing.Union[str, typing.List[str]] = '' pretrained: bool = False exclude_filters: typing.Union[str, typing.List[str]] = '' name_matches_cfg: bool = False include_tags: typing.Optional[bool] = None )
参数
- filter - 适用于 fnmatch 的通配符过滤字符串 —
- module - 将模型选择限制在特定子模块(例如 ‘vision_transformer’) —
- pretrained - 如果为 True,则仅包含具有有效预训练权重的模型 —
- exclude_filters - 在使用 filter 包含模型后,用于排除模型的通配符过滤器 —
- name_matches_cfg - 仅包含模型名称与 default_cfg 名称匹配的模型(排除一些别名)—
- include_tags - 在模型名称中包含预训练标签(model.tag)。如果为 None,则默认为 — 当 pretrained=True 时设置为 True,否则为 False(默认:None)
返回可用模型名称列表,按字母顺序排序
示例:model_list(‘gluon_resnet*’) — 返回所有以 ‘gluon_resnet’ 开头的模型 model_list(‘*resnext*’, ‘resnet’) — 返回 ‘resnet’ 模块中所有包含 ‘resnext’ 的模型