Optimum 文档

ONNX 导出配置类

您正在查看的是需要从源码安装。如果你想通过常规的 pip 安装,请查看最新的稳定版本(v1.27.0)。
Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

ONNX 导出配置类

将模型导出为 ONNX 涉及指定

  1. 输入名称。
  2. 输出名称。
  3. 动态轴。这些指的是运行时可以动态更改的输入维度(例如,批次大小或序列长度)。所有其他轴将被视为静态轴,因此在运行时固定。
  4. 用于跟踪模型的虚拟输入。在 PyTorch 中需要此操作来记录计算图并将其转换为 ONNX。

由于此数据取决于模型和任务的选择,我们将其表示为*配置类*。每个配置类都与特定的模型架构相关联,并遵循命名约定 `ArchitectureNameOnnxConfig`。例如,指定 BERT 模型 ONNX 导出的配置是 `BertOnnxConfig`。

由于许多架构在 ONNX 配置上共享相似的属性,🤗 Optimum 采用了 3 级类层次结构

  1. 抽象和通用的基类。这些类处理所有基本功能,同时与模态(文本、图像、音频等)无关。
  2. 中端类。这些类知道模态,但同一模态可以存在多个,具体取决于它们支持的输入。它们指定应使用哪些输入生成器来生成虚拟输入,但与模型无关。
  3. 特定于模型的类,例如上面提到的 `BertOnnxConfig`。这些是实际用于导出模型的类。

基类

class optimum.exporters.onnx.OnnxConfig

< >

( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )

输入

< >

( ) Dict[str, Dict[int, str]]

返回

Dict[str, Dict[int, str]]

将每个输入名称映射到轴位置与轴符号名称的映射。

包含要提供给模型的输入张量轴定义的字典。

输出

< >

( ) Dict[str, Dict[int, str]]

返回

Dict[str, Dict[int, str]]

将每个输出名称映射到轴位置与轴符号名称的映射。

包含要提供给模型的输出张量轴定义的字典。

生成虚拟输入

< >

( framework: str = 'pt' **kwargs ) Dict[str, [tf.Tensor, torch.Tensor]]

参数

  • framework (str, 默认为 "pt") — 用于创建虚拟输入的框架。
  • batch_size (int, 默认为 2) — 用于虚拟输入的批次大小。
  • sequence_length (int, 默认为 16) — 用于虚拟输入的序列长度。
  • num_choices (int, 默认为 4) — 多项选择任务的候选答案数量。
  • image_width (int, 默认为 64) — 用于视觉任务虚拟输入的宽度。
  • image_height (int, 默认为 64) — 用于视觉任务虚拟输入的高度。
  • num_channels (int, 默认为 3) — 用于视觉任务虚拟输入的通道数。
  • feature_size (int, 默认为 80) — 用于音频任务虚拟输入的特征数量,如果不是原始音频。例如,这是 STFT bin 或 MEL bin 的数量。
  • nb_max_frames (int, 默认为 3000) — 用于音频任务虚拟输入的帧数,如果输入不是原始音频。
  • audio_sequence_length (int, 默认为 16000) — 用于音频任务虚拟输入的帧数,如果输入是原始音频。

返回

Dict[str, [tf.Tensor, torch.Tensor]]

将输入名称映射到正确框架格式的虚拟张量的字典。

生成跟踪模型所需的虚拟输入。如果未明确指定,则使用默认输入形状。

class optimum.exporters.onnx.OnnxConfigWithPast

< >

( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )

继承自 OnnxConfig。处理仅解码器模型 ONNX 配置的基类。

add_past_key_values

< >

( inputs_or_outputs: typing.Dict[str, typing.Dict[int, str]] direction: str )

参数

  • inputs_or_outputs (Dict[str, Dict[int, str]]) — 要填充的映射。
  • direction (str) — “inputs” 或 “outputs”,它指定 `input_or_outputs` 是输入映射还是输出映射,这对于轴命名很重要。

根据方向使用 past_key_values 动态轴填充 `input_or_outputs` 映射。

class optimum.exporters.onnx.OnnxSeq2SeqConfigWithPast

< >

( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False behavior: ConfigBehavior = <ConfigBehavior.MONOLITH: 'monolith'> preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )

继承自 OnnxConfigWithPast。处理编码器-解码器模型 ONNX 配置的基类。

with_behavior

< >

( behavior: typing.Union[str, optimum.exporters.onnx.base.ConfigBehavior] use_past: bool = False use_past_in_inputs: bool = False ) OnnxSeq2SeqConfigWithPast

参数

  • behavior (ConfigBehavior) — 用于新实例的行为。
  • use_past (bool, 默认为 False) — 要实例化的 ONNX 配置是否适用于使用 KV 缓存的模型。
  • use_past_in_inputs (bool, 默认为 False) — KV 缓存是否作为输入传递给 ONNX。

返回

OnnxSeq2SeqConfigWithPast

创建当前 OnnxConfig 的副本,但具有不同的 `ConfigBehavior` 和 `use_past` 值。

中端类

文本

class optimum.exporters.onnx.TextEncoderOnnxConfig

< >

( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )

处理基于编码器的文本架构。

class optimum.exporters.onnx.TextDecoderOnnxConfig

< >

( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )

处理基于解码器的文本架构。

class optimum.exporters.onnx.TextSeq2SeqOnnxConfig

< >

( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False behavior: ConfigBehavior = <ConfigBehavior.MONOLITH: 'monolith'> preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )

处理基于编码器-解码器的文本架构。

视觉

class optimum.exporters.onnx.config.VisionOnnxConfig

< >

( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )

处理视觉架构。

多模态

class optimum.exporters.onnx.config.TextAndVisionOnnxConfig

< >

( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )

处理多模态文本和视觉架构。

< > 在 GitHub 上更新