Optimum 文档
优化
加入 Hugging Face 社区
并获得增强的文档体验
开始使用
优化
optimum.fx.optimization
模块提供了一组 torch.fx 图转换,以及用于编写您自己的转换和组合它们的功能和类。
转换指南
在 🤗 Optimum 中,有两种类型的转换:可逆转换和不可逆转换。
编写不可逆转换
最基本的转换情况是不可逆转换。这些转换是不可逆的,这意味着在将它们应用于图模块后,无法找回原始模型。要在 🤗 Optimum 中实现此类转换非常容易:您只需要继承 Transformation 并实现 transform() 方法。
例如,以下转换将所有乘法更改为加法
>>> import operator
>>> from optimum.fx.optimization import Transformation
>>> class ChangeMulToAdd(Transformation):
... def transform(self, graph_module):
... for node in graph_module.graph.nodes:
... if node.op == "call_function" and node.target == operator.mul:
... node.target = operator.add
... return graph_module
实现后,您的转换可以像常规函数一样使用
>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace
>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
... model,
... input_names=["input_ids", "attention_mask", "token_type_ids"],
... )
>>> transformation = ChangeMulToAdd()
>>> transformed_model = transformation(traced)
编写可逆转换
可逆转换同时实现了转换及其逆转换,从而可以从转换后的模型中检索原始模型。要实现这种转换,您需要继承 ReversibleTransformation 并实现 transform() 和 reverse() 方法。
例如,以下转换是可逆的
>>> import operator
>>> from optimum.fx.optimization import ReversibleTransformation
>>> class MulToMulTimesTwo(ReversibleTransformation):
... def transform(self, graph_module):
... for node in graph_module.graph.nodes:
... if node.op == "call_function" and node.target == operator.mul:
... x, y = node.args
... node.args = (2 * x, y)
... return graph_module
...
... def reverse(self, graph_module):
... for node in graph_module.graph.nodes:
... if node.op == "call_function" and node.target == operator.mul:
... x, y = node.args
... node.args = (x / 2, y)
... return graph_module
组合转换
由于通常需要链式应用多个转换,因此提供了 compose()。它是一个实用程序函数,允许您通过链接多个其他转换来创建转换。
>>> from optimum.fx.optimization import compose
>>> composition = compose(MulToMulTimesTwo(), ChangeMulToAdd())