Optimum 文档

ONNX 导出配置类

您正在查看 main 版本,该版本需要从源码安装。如果您想要常规的 pip 安装,请查看最新的稳定版本 (v1.24.0)。
Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

ONNX 导出配置类

将模型导出到 ONNX 涉及到指定

  1. 输入名称。
  2. 输出名称。
  3. 动态轴。这些轴指的是在运行时可以动态更改的输入维度(例如,批次大小或序列长度)。所有其他轴将被视为静态轴,因此在运行时是固定的。
  4. 用于追踪模型的虚拟输入。这在 PyTorch 中是必需的,用于记录计算图并将其转换为 ONNX。

由于这些数据取决于模型和任务的选择,我们用配置类来表示它。每个配置类都与特定的模型架构相关联,并遵循命名约定 ArchitectureNameOnnxConfig。例如,指定 BERT 模型的 ONNX 导出的配置是 BertOnnxConfig

由于许多架构在它们的 ONNX 配置方面共享相似的属性,🤗 Optimum 采用了 3 级类层次结构

  1. 抽象和通用的基类。这些类处理所有基本功能,同时与模态(文本、图像、音频等)无关。
  2. 中间层类。这些类了解模态,但对于相同的模态,可能存在多个类,具体取决于它们支持的输入。它们指定应该使用哪些输入生成器来生成虚拟输入,但仍然与模型无关。
  3. 特定于模型的类,如上面提到的 BertOnnxConfig。这些类是实际用于导出模型的类。

基类

class optimum.exporters.onnx.OnnxConfig

< >

( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )

参数

  • config (transformers.PretrainedConfig) — 模型配置。
  • task (str, 默认为 "feature-extraction") — 模型应该导出的任务。
  • int_dtype (str, 默认为 "int64") — 整数张量的数据类型,可以是 [“int64”, “int32”, “int8”],默认为 “int64”。
  • float_dtype (str, 默认为 "fp32") — 浮点张量的数据类型,可以是 [“fp32”, “fp16”, “bf16”],默认为 “fp32”。

用于 ONNX 可导出模型的基础类,描述了如何通过 ONNX 格式导出模型的元数据。

类属性

  • NORMALIZED_CONFIG_CLASS (Type) — 从 NormalizedConfig 派生的类,指定如何规范化模型配置。
  • DUMMY_INPUT_GENERATOR_CLASSES (Tuple[Type]) — 从 DummyInputGenerator 派生的类元组,指定如何创建虚拟输入。
  • ATOL_FOR_VALIDATION (Union[float, Dict[str, float]]) — 一个浮点数或一个将任务名称映射到浮点数的字典,其中浮点数值表示模型转换验证期间要使用的绝对容差值。
  • DEFAULT_ONNX_OPSET (int, 默认为 11) — 用于 ONNX 导出的默认 ONNX opset。
  • MIN_TORCH_VERSION (packaging.version.Version, 默认为 ~optimum.exporters.onnx.utils.TORCH_MINIMUM_VERSION) — 支持将模型导出到 ONNX 的最低 torch 版本。
  • MIN_TRANSFORMERS_VERSION (packaging.version.Version, 默认为 ~optimum.exporters.onnx.utils.TRANSFORMERS_MINIMUM_VERSION — 支持将模型导出到 ONNX 的最低 transformers 版本。并非总是最新或准确。这更多是为了内部使用。
  • PATCHING_SPECS (Optional[List[PatchingSpec]], 默认为 None) — 指定在执行导出之前应修补哪些运算符/模块,以及如何修补。当某些运算符在 ONNX 中不受支持时,这很有用。

inputs

< >

( ) Dict[str, Dict[int, str]]

返回

Dict[str, Dict[int, str]]

每个输入名称到轴位置到轴符号名称的映射。

包含要提供给模型的输入张量的轴定义的字典。

outputs

< >

( ) Dict[str, Dict[int, str]]

返回

Dict[str, Dict[int, str]]

每个输出名称到轴位置到轴符号名称的映射。

包含要提供给模型的输出张量的轴定义的字典。

generate_dummy_inputs

< >

( framework: str = 'pt' **kwargs ) Dict

参数

  • framework (str, 默认为 "pt") — 创建虚拟输入所用的框架。
  • batch_size (int, 默认为 2) — 虚拟输入中使用的批次大小。
  • sequence_length (int, 默认为 16) — 虚拟输入中使用的序列长度。
  • num_choices (int, 默认为 4) — 为多项选择任务提供的候选答案的数量。
  • image_width (int, 默认为 64) — 视觉任务的虚拟输入中使用的宽度。
  • image_height (int, 默认为 64) — 视觉任务的虚拟输入中使用的高度。
  • num_channels (int, 默认为 3) — 视觉任务的虚拟输入中使用的通道数。
  • feature_size (int, 默认为 80) — 音频任务的虚拟输入中使用的特征数量,如果输入不是原始音频。例如,这是 STFT bin 或 MEL bin 的数量。
  • nb_max_frames (int, 默认为 3000) — 音频任务的虚拟输入中使用的最大帧数,如果输入不是原始音频。
  • audio_sequence_length (int, 默认为 16000) — 音频任务的虚拟输入中使用的帧数,如果输入是原始音频。

返回

Dict

一个字典,将输入名称映射到正确框架格式的虚拟张量。

生成追踪模型所需的虚拟输入。如果未明确指定,则使用默认输入形状。

class optimum.exporters.onnx.OnnxConfigWithPast

< >

( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )

继承自 OnnxConfig。用于处理仅解码器模型的 ONNX 配置的基类。

add_past_key_values

< >

( inputs_or_outputs: typing.Dict[str, typing.Dict[int, str]] direction: str )

参数

  • inputs_or_outputs (Dict[str, Dict[int, str]]) — 要填充的映射。
  • direction (str) — “inputs” 或 “outputs” 之一,它指定了 input_or_outputs 是输入映射还是输出映射,这对于轴命名非常重要。

考虑方向,使用 past_key_values 动态轴填充 input_or_outputs 映射。

class optimum.exporters.onnx.OnnxSeq2SeqConfigWithPast

< >

( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False behavior: ConfigBehavior = <ConfigBehavior.MONOLITH: 'monolith'> preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )

继承自 OnnxConfigWithPast。用于处理编码器-解码器模型的 ONNX 配置的基类。

with_behavior

< >

( behavior: typing.Union[str, optimum.exporters.onnx.base.ConfigBehavior] use_past: bool = False use_past_in_inputs: bool = False ) OnnxSeq2SeqConfigWithPast

参数

  • behavior (ConfigBehavior) — 用于新实例的行为。
  • use_past (bool, 默认为 False) — 实例化的 ONNX 配置是否用于使用 KV 缓存的模型。
  • use_past_in_inputs (bool, 默认为 False) — KV 缓存是否将作为 ONNX 的输入传递。

返回

OnnxSeq2SeqConfigWithPast

创建当前 OnnxConfig 的副本,但具有不同的 ConfigBehavioruse_past 值。

中间层类

文本

class optimum.exporters.onnx.TextEncoderOnnxConfig

< >

( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )

处理基于编码器的文本架构。

class optimum.exporters.onnx.TextDecoderOnnxConfig

< >

( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )

处理基于解码器的文本架构。

class optimum.exporters.onnx.TextSeq2SeqOnnxConfig

< >

( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False behavior: ConfigBehavior = <ConfigBehavior.MONOLITH: 'monolith'> preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )

处理基于编码器-解码器的文本架构。

视觉

class optimum.exporters.onnx.config.VisionOnnxConfig

< >

( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )

处理视觉架构。

多模态

class optimum.exporters.onnx.config.TextAndVisionOnnxConfig

< >

( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )

处理多模态文本和视觉架构。

< > 在 GitHub 上更新