将 🤗 Transformers 模型导出到 ONNX
🤗 Transformers 提供了一个 transformers.onnx
包,使您能够通过利用配置对象将模型检查点转换为 ONNX 图。
有关导出 🤗 Transformers 模型的更多详细信息,请参阅指南。
ONNX 配置
我们提供了三个抽象类,您应该根据要导出的模型架构类型继承它们
- 基于编码器的模型继承自 OnnxConfig
- 基于解码器的模型继承自 OnnxConfigWithPast
- 编码器-解码器模型继承自 OnnxSeq2SeqConfigWithPast
OnnxConfig
类 transformers.onnx.OnnxConfig
< 源代码 >( config: PretrainedConfig task: str = 'default' patching_specs: List = None )
用于 ONNX 可导出模型的基类,描述了如何通过 ONNX 格式导出模型的元数据。
flatten_output_collection_property
< 源代码 >( name: str field: Iterable ) → (Dict[str, Any])
返回
(Dict[str, Any])
具有扁平化结构和映射此新结构的键的输出。
展平任何潜在的嵌套结构,并使用结构中元素的索引扩展字段的名称。
为特定模型实例化 OnnxConfig
generate_dummy_inputs
< 源代码 >( preprocessor: Union batch_size: int = -1 seq_length: int = -1 num_choices: int = -1 is_pair: bool = False framework: Optional = None num_channels: int = 3 image_width: int = 40 image_height: int = 40 sampling_rate: int = 22050 time_duration: float = 5.0 frequency: int = 220 tokenizer: PreTrainedTokenizerBase = None )
参数
- batch_size (
int
,可选,默认为 -1) — 导出模型的批量大小(-1 表示动态轴)。 - num_choices (
int
,可选,默认为 -1) — 为多选任务提供的候选答案数量(-1 表示动态轴)。 - seq_length (
int
,可选,默认为 -1) — 导出模型的序列长度(-1 表示动态轴)。 - is_pair (
bool
,可选,默认为False
) — 指示输入是否为一对(句子 1,句子 2) - framework (
TensorType
,可选,默认为None
) — 分词器将为其生成张量的框架(PyTorch 或 TensorFlow)。 - num_channels (
int
,可选,默认为 3) — 生成图像的通道数。 - image_width (
int
,可选,默认为 40) — 生成图像的宽度。 - image_height (
int
,可选,默认为 40) — 生成图像的高度。 - sampling_rate (
int
,可选,默认为 22050) — 用于生成音频数据的采样率。 - time_duration (
float
,可选,默认为 5.0) — 用于生成音频数据的总采样秒数。 - frequency (
int
,可选,默认为 220) — 生成的音频所需的自然频率。
生成要提供给特定框架的 ONNX 导出器的输入
generate_dummy_inputs_onnxruntime
< 源代码 >( reference_model_inputs: Mapping ) → Mapping[str, Tensor]
使用参考模型输入为 ONNX Runtime 生成输入。 覆盖此方法以使用 seq2seq 模型运行推理,这些模型的编码器和解码器作为单独的 ONNX 文件导出。
指示模型是否需要使用外部数据格式的标志
OnnxConfigWithPast
类 transformers.onnx.OnnxConfigWithPast
< 源代码 >( config: PretrainedConfig task: str = 'default' patching_specs: List = None use_past: bool = False )
fill_with_past_key_values_
< 源代码 >( inputs_or_outputs: Mapping direction: str inverted_values_shape: bool = False )
使用 past_key_values 动态轴填充 input_or_outputs 映射。
实例化一个 use_past
属性设置为 True 的 OnnxConfig
OnnxSeq2SeqConfigWithPast
类 transformers.onnx.OnnxSeq2SeqConfigWithPast
< 源代码 >( config: PretrainedConfig task: str = 'default' patching_specs: List = None use_past: bool = False )
ONNX 功能
每个 ONNX 配置都与一组*功能*相关联,这些功能使您能够导出用于不同类型拓扑或任务的模型。
FeaturesManager
检查模型是否具有请求的功能。
determine_framework
< 源代码 >( model: str framework: str = None )
确定用于导出的框架。
优先级顺序如下
- 用户通过
framework
输入。 - 如果提供了本地检查点,则使用与检查点相同的框架。
- 环境中可用的框架,优先考虑 PyTorch
get_config
< 源代码 >( model_type: str feature: str ) → OnnxConfig
获取 model_type 和 feature 组合的 OnnxConfig。
get_model_class_for_feature
< 源代码 >( feature: str framework: str = 'pt' )
尝试从功能名称检索 AutoModel 类。
get_model_from_feature
< 源代码 >( feature: str model: str framework: str = None cache_dir: str = None )
尝试从模型名称和要启用的功能检索模型。
get_supported_features_for_model_type
< 源代码 >( model_type: str model_name: Optional = None )
尝试从模型类型中检索特性 -> OnnxConfig 构造函数映射。