Optimum 文档
ONNX 导出配置类
并获得增强的文档体验
开始使用
ONNX 导出配置类
将模型导出为 ONNX 涉及指定
- 输入名称。
- 输出名称。
- 动态轴。这些指的是运行时可以动态更改的输入维度(例如,批次大小或序列长度)。所有其他轴将被视为静态轴,因此在运行时固定。
- 用于跟踪模型的虚拟输入。在 PyTorch 中需要此操作来记录计算图并将其转换为 ONNX。
由于此数据取决于模型和任务的选择,我们将其表示为*配置类*。每个配置类都与特定的模型架构相关联,并遵循命名约定 `ArchitectureNameOnnxConfig`。例如,指定 BERT 模型 ONNX 导出的配置是 `BertOnnxConfig`。
由于许多架构在 ONNX 配置上共享相似的属性,🤗 Optimum 采用了 3 级类层次结构
- 抽象和通用的基类。这些类处理所有基本功能,同时与模态(文本、图像、音频等)无关。
- 中端类。这些类知道模态,但同一模态可以存在多个,具体取决于它们支持的输入。它们指定应使用哪些输入生成器来生成虚拟输入,但与模型无关。
- 特定于模型的类,例如上面提到的 `BertOnnxConfig`。这些是实际用于导出模型的类。
基类
class optimum.exporters.onnx.OnnxConfig
< source >( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )
包含要提供给模型的输入张量轴定义的字典。
包含要提供给模型的输出张量轴定义的字典。
生成虚拟输入
< source >( framework: str = 'pt' **kwargs ) → Dict[str, [tf.Tensor, torch.Tensor]]
参数
- framework (
str
, 默认为"pt"
) — 用于创建虚拟输入的框架。 - batch_size (
int
, 默认为 2) — 用于虚拟输入的批次大小。 - sequence_length (
int
, 默认为 16) — 用于虚拟输入的序列长度。 - num_choices (
int
, 默认为 4) — 多项选择任务的候选答案数量。 - image_width (
int
, 默认为 64) — 用于视觉任务虚拟输入的宽度。 - image_height (
int
, 默认为 64) — 用于视觉任务虚拟输入的高度。 - num_channels (
int
, 默认为 3) — 用于视觉任务虚拟输入的通道数。 - feature_size (
int
, 默认为 80) — 用于音频任务虚拟输入的特征数量,如果不是原始音频。例如,这是 STFT bin 或 MEL bin 的数量。 - nb_max_frames (
int
, 默认为 3000) — 用于音频任务虚拟输入的帧数,如果输入不是原始音频。 - audio_sequence_length (
int
, 默认为 16000) — 用于音频任务虚拟输入的帧数,如果输入是原始音频。
返回
Dict[str, [tf.Tensor, torch.Tensor]]
将输入名称映射到正确框架格式的虚拟张量的字典。
生成跟踪模型所需的虚拟输入。如果未明确指定,则使用默认输入形状。
class optimum.exporters.onnx.OnnxConfigWithPast
< source >( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )
继承自 OnnxConfig。处理仅解码器模型 ONNX 配置的基类。
add_past_key_values
< source >( inputs_or_outputs: typing.Dict[str, typing.Dict[int, str]] direction: str )
根据方向使用 past_key_values 动态轴填充 `input_or_outputs` 映射。
class optimum.exporters.onnx.OnnxSeq2SeqConfigWithPast
< source >( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False behavior: ConfigBehavior = <ConfigBehavior.MONOLITH: 'monolith'> preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )
继承自 OnnxConfigWithPast。处理编码器-解码器模型 ONNX 配置的基类。
with_behavior
< source >( behavior: typing.Union[str, optimum.exporters.onnx.base.ConfigBehavior] use_past: bool = False use_past_in_inputs: bool = False ) → OnnxSeq2SeqConfigWithPast
创建当前 OnnxConfig 的副本,但具有不同的 `ConfigBehavior` 和 `use_past` 值。
中端类
文本
class optimum.exporters.onnx.TextEncoderOnnxConfig
< source >( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )
处理基于编码器的文本架构。
class optimum.exporters.onnx.TextDecoderOnnxConfig
< source >( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )
处理基于解码器的文本架构。
class optimum.exporters.onnx.TextSeq2SeqOnnxConfig
< source >( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False behavior: ConfigBehavior = <ConfigBehavior.MONOLITH: 'monolith'> preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )
处理基于编码器-解码器的文本架构。
视觉
class optimum.exporters.onnx.config.VisionOnnxConfig
< source >( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )
处理视觉架构。
多模态
class optimum.exporters.onnx.config.TextAndVisionOnnxConfig
< source >( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )
处理多模态文本和视觉架构。