模型
timm.create_model
< 源代码 >( model_name: str pretrained: bool = False pretrained_cfg: Union = None pretrained_cfg_overlay: Optional = None checkpoint_path: str = '' scriptable: Optional = None exportable: Optional = None no_jit: Optional = None **kwargs )
创建模型。
查找模型的入口函数,并传递相关参数以创建新模型。
**kwargs 将通过入口函数传递给 `timm.models.build_model_with_cfg()`,然后传递给模型类的 `__init__`。kwargs 中的值设置为 None 会在传递之前被修剪。
关键字参数:drop_rate (float): 用于训练的分类器 dropout 率。drop_path_rate (float): 用于训练的随机深度 dropout 率。global_pool (str): 分类器全局池化类型。
示例
>>> from timm import create_model
>>> # Create a MobileNetV3-Large model with no pretrained weights.
>>> model = create_model('mobilenetv3_large_100')
>>> # Create a MobileNetV3-Large model with pretrained weights.
>>> model = create_model('mobilenetv3_large_100', pretrained=True)
>>> model.num_classes
1000
>>> # Create a MobileNetV3-Large model with pretrained weights and a new head with 10 classes.
>>> model = create_model('mobilenetv3_large_100', pretrained=True, num_classes=10)
>>> model.num_classes
10
timm.list_models
< 源代码 >( filter: Union = '' module: Union = '' pretrained: bool = False exclude_filters: Union = '' name_matches_cfg: bool = False include_tags: Optional = None )
参数
- filter - 与 fnmatch 匹配的通配符过滤器字符串 -
- module - 将模型选择限制在特定子模块(如 'vision_transformer') -
- pretrained - 如果为 True,则仅包含具有有效预训练权重的模型 -
- exclude_filters - 使用通配符过滤,在使用过滤器包含模型后排除模型 -
- name_matches_cfg - 仅包含模型名称与默认配置名称匹配的模型(排除某些别名) -
- include_tags - 在模型名称中包含预训练标签(model.tag)。如果为 None,则默认为预训练为 True 时设置为 True,否则为 False(默认:None)
返回可用的模型名称列表,按字母顺序排序
示例:model_list(’gluon_resnet’) - 返回所有以 ‘gluon_resnet’ 开头的模型 model_list(’resnext*, ‘resnet’) - 返回 ‘resnet’ 模块中包含 ‘resnext’ 的所有模型