Optimum 文档

添加对不支持架构的支持

您正在查看 main 版本,该版本需要从源代码安装。如果您想要常规 pip 安装,请查看最新的稳定版本 (v1.24.0)。
Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

添加对不支持架构的支持

如果您希望导出库中尚未支持其架构的模型,则需要遵循以下主要步骤

  1. 实现自定义 ONNX 配置。
  2. TasksManager 中注册 ONNX 配置。
  3. 将模型导出为 ONNX。
  4. 验证原始模型和导出模型的输出。

在本节中,我们将研究 BERT 的实现方式,以展示每个步骤涉及的内容。

实现自定义 ONNX 配置

让我们从 ONNX 配置对象开始。我们提供了一个 3 级 类层次结构,并且为了添加对模型的支持,继承正确的中端类将是大多数时候的选择。如果您要添加一个处理以前从未见过的模态和/或情况的架构,您可能需要自己实现一个中端类。

实现自定义 ONNX 配置的一个好方法是查看 optimum/exporters/onnx/model_configs.py 文件中现有的配置实现。

此外,如果您尝试添加的架构与已支持的架构(例如,当已经支持 BERT 时添加对 ALBERT 的支持)非常相似,那么尝试简单地从该类继承可能会起作用。

当从中端类继承时,请查找处理与您尝试支持的模型相同的模态/类别的类。

示例:添加对 BERT 的支持

由于 BERT 是一个基于编码器的文本模型,因此其配置继承自中端类 TextEncoderOnnxConfig。在 optimum/exporters/onnx/model_configs.py

# This class is actually in optimum/exporters/onnx/config.py
class TextEncoderOnnxConfig(OnnxConfig):
    # Describes how to generate the dummy inputs.
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator,)

class BertOnnxConfig(TextEncoderOnnxConfig):
    # Specifies how to normalize the BertConfig, this is needed to access common attributes
    # during dummy input generation.
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    # Sets the absolute tolerance to when validating the exported ONNX model against the
    # reference model.
    ATOL_FOR_VALIDATION = 1e-4

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch_size", 1: "num_choices", 2: "sequence_length"}
        else:
            dynamic_axis = {0: "batch_size", 1: "sequence_length"}
        return {
            "input_ids": dynamic_axis,
            "attention_mask": dynamic_axis,
            "token_type_ids": dynamic_axis,
        }

首先,让我们解释一下 TextEncoderOnnxConfig 是关于什么的。虽然大多数功能已经在 OnnxConfig 中实现,但这个类是与模态无关的,这意味着它不知道应该处理哪种类型的输入。输入生成的方式是通过 DUMMY_INPUT_GENERATOR_CLASSES 属性处理的,该属性是 DummyInputGenerator 的元组。在这里,我们通过指定 DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator,) 来创建一个继承自 OnnxConfig 的模态感知配置。

接下来是特定于模型的类 BertOnnxConfig。这里指定了两个类属性

  • NORMALIZED_CONFIG_CLASS:这必须是一个 NormalizedConfig,它基本上允许输入生成器以通用方式访问模型配置属性。
  • ATOL_FOR_VALIDATION:当针对原始模型验证导出的模型时使用,这是输出值差异的绝对可接受容差。

每个配置对象都必须实现 inputs 属性并返回一个映射,其中每个键对应于一个输入名称,每个值指示该输入中是动态的轴。对于 BERT,我们可以看到需要三个输入:input_idsattention_masktoken_type_ids。这些输入具有相同的形状 (batch_size, sequence_length)(除了 multiple-choice 任务),这就是为什么我们在配置中看到使用了相同的轴。

一旦您实现了 ONNX 配置,您可以通过提供基本模型的配置来实例化它,如下所示

>>> from transformers import AutoConfig
>>> from optimum.exporters.onnx.model_configs import BertOnnxConfig
>>> config = AutoConfig.from_pretrained("bert-base-uncased")
>>> onnx_config = BertOnnxConfig(config)

生成的对象具有几个有用的属性。例如,您可以查看导出期间将使用的 ONNX 运算符集

>>> print(onnx_config.DEFAULT_ONNX_OPSET)
11

您还可以按如下方式查看与模型关联的输出

>>> print(onnx_config.outputs)
OrderedDict([('last_hidden_state', {0: 'batch_size', 1: 'sequence_length'})])

请注意,outputs 属性遵循与 inputs 相同的结构;它返回命名输出及其形状的 OrderedDict。输出结构与配置初始化的任务选择相关。默认情况下,ONNX 配置使用 default 任务初始化,该任务对应于导出使用 AutoModel 类加载的模型。如果您想为另一个任务导出模型,只需在初始化 ONNX 配置时为 task 参数提供不同的任务即可。例如,如果我们希望导出带有序列分类头的 BERT,我们可以使用

>>> from transformers import AutoConfig

>>> config = AutoConfig.from_pretrained("bert-base-uncased")
>>> onnx_config_for_seq_clf = BertOnnxConfig(config, task="text-classification")
>>> print(onnx_config_for_seq_clf.outputs)
OrderedDict([('logits', {0: 'batch_size'})])

查看 BartOnnxConfig 以获取高级示例。

在 TasksManager 中注册 ONNX 配置

TasksManager 是加载给定名称和任务的模型的以及获取给定(架构、后端)对的正确配置的主要入口点。当添加对导出到 ONNX 的支持时,将配置注册到 TasksManager 将使导出在命令行工具中可用。

为此,请在 _SUPPORTED_MODEL_TYPE 属性中添加一个条目

  • 如果该模型已经支持 ONNX 以外的其他后端,它将已经有一个条目,因此您只需要添加一个 onnx 键,指定配置类的名称。
  • 否则,您将必须添加整个条目。

对于 BERT,它看起来如下所示

    "bert": supported_tasks_mapping(
        "default",
        "fill-mask",
        "text-generation",
        "text-classification",
        "multiple-choice",
        "token-classification",
        "question-answering",
        onnx="BertOnnxConfig",
    )

导出模型

一旦您实现了 ONNX 配置,下一步就是导出模型。在这里,我们可以使用 optimum.exporters.onnx 包提供的 export() 函数。此函数需要 ONNX 配置,以及基本模型和保存导出文件的路径

>>> from pathlib import Path
>>> from optimum.exporters import TasksManager
>>> from optimum.exporters.onnx import export
>>> from transformers import AutoModel

>>> base_model = AutoModel.from_pretrained("bert-base-uncased")

>>> onnx_path = Path("model.onnx")
>>> onnx_config_constructor = TasksManager.get_exporter_config_constructor("onnx", base_model)
>>> onnx_config = onnx_config_constructor(base_model.config)

>>> onnx_inputs, onnx_outputs = export(base_model, onnx_config, onnx_path, onnx_config.DEFAULT_ONNX_OPSET)

export() 函数返回的 onnx_inputsonnx_outputsinputsinputs 配置属性中定义的键的列表。模型导出后,您可以按如下方式测试模型是否格式正确

>>> import onnx

>>> onnx_model = onnx.load("model.onnx")
>>> onnx.checker.check_model(onnx_model)

如果您的模型大于 2GB,您将看到在导出期间创建了许多附加文件。这是预期的,因为 ONNX 使用 Protocol Buffers 来存储模型,而这些模型的大小限制为 2GB。有关如何加载具有外部数据的模型的说明,请参阅 ONNX 文档

验证模型输出

最后一步是验证来自基本模型和导出模型的输出在某个绝对容差范围内一致。在这里,我们可以使用 optimum.exporters.onnx 包提供的 validate_model_outputs() 函数

>>> from optimum.exporters.onnx import validate_model_outputs

>>> validate_model_outputs(
...     onnx_config, base_model, onnx_path, onnx_outputs, onnx_config.ATOL_FOR_VALIDATION
... )

向 🤗 Optimum 贡献新配置

现在已经实现了对架构的支持并进行了验证,还剩下两件事

  1. 将您的模型架构添加到 tests/exporters/test_onnx_export.py 中的测试
  2. optimum 仓库 上创建 PR

感谢您的贡献!

< > 更新 在 GitHub 上