Optimum 文档
ONNX 导出配置类
并获得增强的文档体验
开始使用
ONNX 导出配置类
将模型导出到 ONNX 涉及到指定
- 输入名称。
- 输出名称。
- 动态轴。这些轴指的是在运行时可以动态更改的输入维度(例如,批次大小或序列长度)。所有其他轴将被视为静态轴,因此在运行时是固定的。
- 用于追踪模型的虚拟输入。这在 PyTorch 中是必需的,用于记录计算图并将其转换为 ONNX。
由于这些数据取决于模型和任务的选择,我们用配置类来表示它。每个配置类都与特定的模型架构相关联,并遵循命名约定 ArchitectureNameOnnxConfig
。例如,指定 BERT 模型的 ONNX 导出的配置是 BertOnnxConfig
。
由于许多架构在它们的 ONNX 配置方面共享相似的属性,🤗 Optimum 采用了 3 级类层次结构
- 抽象和通用的基类。这些类处理所有基本功能,同时与模态(文本、图像、音频等)无关。
- 中间层类。这些类了解模态,但对于相同的模态,可能存在多个类,具体取决于它们支持的输入。它们指定应该使用哪些输入生成器来生成虚拟输入,但仍然与模型无关。
- 特定于模型的类,如上面提到的
BertOnnxConfig
。这些类是实际用于导出模型的类。
基类
class optimum.exporters.onnx.OnnxConfig
< source >( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )
用于 ONNX 可导出模型的基础类,描述了如何通过 ONNX 格式导出模型的元数据。
类属性
- NORMALIZED_CONFIG_CLASS (
Type
) — 从 NormalizedConfig 派生的类,指定如何规范化模型配置。 - DUMMY_INPUT_GENERATOR_CLASSES (
Tuple[Type]
) — 从 DummyInputGenerator 派生的类元组,指定如何创建虚拟输入。 - ATOL_FOR_VALIDATION (
Union[float, Dict[str, float]]
) — 一个浮点数或一个将任务名称映射到浮点数的字典,其中浮点数值表示模型转换验证期间要使用的绝对容差值。 - DEFAULT_ONNX_OPSET (
int
, 默认为 11) — 用于 ONNX 导出的默认 ONNX opset。 - MIN_TORCH_VERSION (
packaging.version.Version
, 默认为~optimum.exporters.onnx.utils.TORCH_MINIMUM_VERSION
) — 支持将模型导出到 ONNX 的最低 torch 版本。 - MIN_TRANSFORMERS_VERSION (
packaging.version.Version
, 默认为~optimum.exporters.onnx.utils.TRANSFORMERS_MINIMUM_VERSION
— 支持将模型导出到 ONNX 的最低 transformers 版本。并非总是最新或准确。这更多是为了内部使用。 - PATCHING_SPECS (
Optional[List[PatchingSpec]]
, 默认为None
) — 指定在执行导出之前应修补哪些运算符/模块,以及如何修补。当某些运算符在 ONNX 中不受支持时,这很有用。
包含要提供给模型的输入张量的轴定义的字典。
outputs
< source >( ) → Dict[str, Dict[int, str]]
返回
Dict[str, Dict[int, str]]
每个输出名称到轴位置到轴符号名称的映射。
包含要提供给模型的输出张量的轴定义的字典。
generate_dummy_inputs
< source >( framework: str = 'pt' **kwargs ) → Dict
参数
- framework (
str
, 默认为"pt"
) — 创建虚拟输入所用的框架。 - batch_size (
int
, 默认为 2) — 虚拟输入中使用的批次大小。 - sequence_length (
int
, 默认为 16) — 虚拟输入中使用的序列长度。 - num_choices (
int
, 默认为 4) — 为多项选择任务提供的候选答案的数量。 - image_width (
int
, 默认为 64) — 视觉任务的虚拟输入中使用的宽度。 - image_height (
int
, 默认为 64) — 视觉任务的虚拟输入中使用的高度。 - num_channels (
int
, 默认为 3) — 视觉任务的虚拟输入中使用的通道数。 - feature_size (
int
, 默认为 80) — 音频任务的虚拟输入中使用的特征数量,如果输入不是原始音频。例如,这是 STFT bin 或 MEL bin 的数量。 - nb_max_frames (
int
, 默认为 3000) — 音频任务的虚拟输入中使用的最大帧数,如果输入不是原始音频。 - audio_sequence_length (
int
, 默认为 16000) — 音频任务的虚拟输入中使用的帧数,如果输入是原始音频。
返回
Dict
一个字典,将输入名称映射到正确框架格式的虚拟张量。
生成追踪模型所需的虚拟输入。如果未明确指定,则使用默认输入形状。
class optimum.exporters.onnx.OnnxConfigWithPast
< source >( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )
继承自 OnnxConfig。用于处理仅解码器模型的 ONNX 配置的基类。
add_past_key_values
< source >( inputs_or_outputs: typing.Dict[str, typing.Dict[int, str]] direction: str )
考虑方向,使用 past_key_values 动态轴填充 input_or_outputs
映射。
class optimum.exporters.onnx.OnnxSeq2SeqConfigWithPast
< source >( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False behavior: ConfigBehavior = <ConfigBehavior.MONOLITH: 'monolith'> preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )
继承自 OnnxConfigWithPast。用于处理编码器-解码器模型的 ONNX 配置的基类。
with_behavior
< source >( behavior: typing.Union[str, optimum.exporters.onnx.base.ConfigBehavior] use_past: bool = False use_past_in_inputs: bool = False ) → OnnxSeq2SeqConfigWithPast
创建当前 OnnxConfig 的副本,但具有不同的 ConfigBehavior
和 use_past
值。
中间层类
文本
class optimum.exporters.onnx.TextEncoderOnnxConfig
< source >( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )
处理基于编码器的文本架构。
class optimum.exporters.onnx.TextDecoderOnnxConfig
< source >( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )
处理基于解码器的文本架构。
class optimum.exporters.onnx.TextSeq2SeqOnnxConfig
< source >( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False behavior: ConfigBehavior = <ConfigBehavior.MONOLITH: 'monolith'> preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )
处理基于编码器-解码器的文本架构。
视觉
class optimum.exporters.onnx.config.VisionOnnxConfig
< source >( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )
处理视觉架构。
多模态
class optimum.exporters.onnx.config.TextAndVisionOnnxConfig
< source >( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )
处理多模态文本和视觉架构。