Optimum 文档

优化

您正在查看 主分支 版本,需要从源代码安装。如果您希望使用常规的 pip 安装,请查看最新稳定版本(v1.23.1)。
Hugging Face's logo
加入 Hugging Face 社区

并获得增强文档体验

开始使用

优化

optimum.fx.optimization 模块提供了一组 torch.fx 图变换,以及用于编写自己的变换并将其组合在一起的类和函数。

变换指南

在 🤗 Optimum 中,有两种类型的变换:可逆变换和不可逆变换。

编写不可逆变换

变换最基本的情况是不可逆变换。这些变换无法逆转,这意味着在将它们应用于图模块后,无法恢复原始模型。在 🤗 Optimum 中实现此类变换非常简单:您只需要子类化 Transformation 并实现 transform() 方法。

例如,以下变换将所有乘法更改为加法

>>> import operator
>>> from optimum.fx.optimization import Transformation

>>> class ChangeMulToAdd(Transformation):
...     def transform(self, graph_module):
...         for node in graph_module.graph.nodes:
...             if node.op == "call_function" and node.target == operator.mul:
...                 node.target = operator.add
...         return graph_module

实现后,您的变换可以用作常规函数

>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace

>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
...     model,
...     input_names=["input_ids", "attention_mask", "token_type_ids"],
... )

>>> transformation = ChangeMulToAdd()
>>> transformed_model = transformation(traced)

编写可逆变换

可逆变换同时实现了变换及其逆变换,允许从变换后的模型中检索原始模型。要实现此类变换,您需要子类化 ReversibleTransformation 并实现 transform()reverse() 方法。

例如,以下变换是可逆的

>>> import operator
>>> from optimum.fx.optimization import ReversibleTransformation

>>> class MulToMulTimesTwo(ReversibleTransformation):
...     def transform(self, graph_module):
...         for node in graph_module.graph.nodes:
...             if node.op == "call_function" and node.target == operator.mul:
...                 x, y = node.args
...                 node.args = (2 * x, y)
...         return graph_module
...
...     def reverse(self, graph_module):
...         for node in graph_module.graph.nodes:
...             if node.op == "call_function" and node.target == operator.mul:
...                 x, y = node.args
...                 node.args = (x / 2, y)
...         return graph_module

组合变换

由于需要多次链接应用变换的情况并不少见,因此提供了 compose()。这是一个实用程序函数,允许您通过链接多个其他变换来创建变换。

>>> from optimum.fx.optimization import compose
>>> composition = compose(MulToMulTimesTwo(), ChangeMulToAdd())
< > 在 GitHub 上更新