用于 ONNX 导出的配置类
将模型导出到 ONNX 需要指定
- 输入名称。
- 输出名称。
- 动态轴。这些指的是在运行时可以动态更改的输入维度(例如批大小或序列长度)。所有其他轴将被视为静态,因此在运行时固定。
- 用于跟踪模型的虚拟输入。这在 PyTorch 中是必需的,用于记录计算图并将其转换为 ONNX。
由于此数据取决于模型和任务的选择,因此我们使用 *配置类* 来表示它。每个配置类都与特定的模型架构相关联,并遵循命名约定 ArchitectureNameOnnxConfig
。例如,指定 BERT 模型的 ONNX 导出的配置是 BertOnnxConfig
。
由于许多架构在其 ONNX 配置方面共享类似的属性,🤗 Optimum 采用了一个三级类层次结构
- 抽象和通用的基类。这些处理所有基本功能,同时与模态(文本、图像、音频等)无关。
- 中间类。这些类了解模态,但多个类可以存在于同一模态中,具体取决于它们支持的输入。它们指定应使用哪些输入生成器作为虚拟输入,但保持与模型无关。
- 模型特定的类,例如上面提到的
BertOnnxConfig
。这些是实际用于导出模型的类。
基类
class optimum.exporters.onnx.OnnxConfig
< 源代码 >( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )
用于 ONNX 可导出模型的基类,描述了如何通过 ONNX 格式导出模型的元数据。
类属性
- NORMALIZED_CONFIG_CLASS (
Type
) — 从 NormalizedConfig 派生的类,指定如何规范化模型配置。 - DUMMY_INPUT_GENERATOR_CLASSES (
Tuple[Type]
) — 从 DummyInputGenerator 派生的类的元组,指定如何创建虚拟输入。 - ATOL_FOR_VALIDATION (
Union[float, Dict[str, float]]
) — 一个浮点数或一个字典,将任务名称映射到浮点数,其中浮点值表示模型转换验证期间要使用的绝对容差值。 - DEFAULT_ONNX_OPSET (
int
, defaults to 11) — 用于 ONNX 导出的默认 ONNX 操作集。 - MIN_TORCH_VERSION (
packaging.version.Version
, defaults to~optimum.exporters.onnx.utils.TORCH_MINIMUM_VERSION
) — 支持将模型导出到 ONNX 的最小 torch 版本。 - MIN_TRANSFORMERS_VERSION (
packaging.version.Version
, defaults to~optimum.exporters.onnx.utils.TRANSFORMERS_MINIMUM_VERSION
— 支持将模型导出到 ONNX 的最小 transformers 版本。 不总是最新或准确的。 这更多用于内部使用。 - PATCHING_SPECS (
Optional[List[PatchingSpec]]
, defaults toNone
) — 指定在执行导出之前应修补哪些运算符/模块,以及如何修补。 例如,当某些运算符在 ONNX 中不受支持时,这很有用。
包含要提供给模型的输入张量的轴定义的字典。
outputs
< source >( ) → Dict[str, Dict[int, str]]
返回
Dict[str, Dict[int, str]]
每个输出名称到轴位置到轴符号名称的映射。
包含要提供给模型的输出张量的轴定义的字典。
generate_dummy_inputs
< source >( framework: str = 'pt' **kwargs ) → Dict
参数
- framework (
str
, defaults to"pt"
) — 用于创建虚拟输入的框架。 - batch_size (
int
, defaults to 2) — 虚拟输入中要使用的批次大小。 - sequence_length (
int
, defaults to 16) — 虚拟输入中要使用的序列长度。 - num_choices (
int
, defaults to 4) — 多项选择任务提供的候选答案数量。 - image_width (
int
, defaults to 64) — 视觉任务中虚拟输入中要使用的宽度。 - num_channels (
int
, 默认为 3) — 用于视觉任务的虚拟输入的通道数量。 - feature_size (
int
, 默认为 80) — 如果音频任务的输入不是原始音频,则用于虚拟输入的特征数量。 例如,STFT 频段或 MEL 频段的数量。 - nb_max_frames (
int
, 默认为 3000) — 如果音频任务的输入不是原始音频,则用于虚拟输入的帧数。 - audio_sequence_length (
int
, 默认为 16000) — 如果音频任务的输入是原始音频,则用于虚拟输入的帧数。
返回
字典
一个字典,将输入名称映射到适当框架格式的虚拟张量。
生成跟踪模型所需的虚拟输入。 如果未明确指定,则使用默认输入形状。
类 optimum.exporters.onnx.OnnxConfigWithPast
< 源代码 >( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )
继承自 OnnxConfig。 用于处理仅解码器模型的 ONNX 配置的基类。
add_past_key_values
< 源代码 >( inputs_or_outputs: typing.Dict[str, typing.Dict[int, str]] direction: str )
使用过去关键值动态轴填充 input_or_outputs
映射,考虑方向。
类 optimum.exporters.onnx.OnnxSeq2SeqConfigWithPast
< 源代码 >( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False behavior: ConfigBehavior = <ConfigBehavior.MONOLITH: 'monolith'> preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )
继承自 OnnxConfigWithPast。用于处理编码器-解码器模型的 ONNX 配置的基类。
with_behavior
< 源代码 >( behavior: typing.Union[str, optimum.exporters.onnx.base.ConfigBehavior] use_past: bool = False use_past_in_inputs: bool = False ) → OnnxSeq2SeqConfigWithPast
创建一个当前 OnnxConfig 的副本,但具有不同的 ConfigBehavior
和 use_past
值。
中间端类
文本
类 optimum.exporters.onnx.TextEncoderOnnxConfig
< 源代码 >( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )
处理基于编码器的文本架构。
类 optimum.exporters.onnx.TextDecoderOnnxConfig
< 源代码 >( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )
处理基于解码器的文本架构。
类 optimum.exporters.onnx.TextSeq2SeqOnnxConfig
< 源代码 >( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False behavior: ConfigBehavior = <ConfigBehavior.MONOLITH: 'monolith'> preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )
处理基于编码器-解码器的文本架构。
视觉
类 optimum.exporters.onnx.config.VisionOnnxConfig
< 源代码 >( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )
处理视觉架构。
多模态
类 optimum.exporters.onnx.config.TextAndVisionOnnxConfig
< 源代码 >( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )
处理多模态文本和视觉架构。