Optimum 文档

概述

您正在查看 主分支 版本,需要从源代码安装。如果您想要使用常规的 pip 安装,请查看最新的稳定版本(v1.23.1)。
Hugging Face's logo
加入 Hugging Face 社区

并获得增强文档体验

开始使用

概述

🤗 Optimum 提供了一个名为 BetterTransformer 的 API,它是一个标准 PyTorch Transformer API 的快速路径,通过稀疏性和融合内核(如 Flash Attention)来利用有趣的加速效果,从而在 CPU 和 GPU 上获得性能提升。目前,BetterTransformer 支持来自原生 nn.TransformerEncoderLayer 的快速路径,以及来自 torch.nn.functional.scaled_dot_product_attention 的 Flash Attention 和内存高效注意力。

快速入门

从 1.13 版本开始,PyTorch 发布了其标准 Transformer API 的快速路径的稳定版本,该版本为基于 Transformer 的模型提供了开箱即用的性能改进。您可以在大多数消费级设备上获得有趣的加速效果,包括 CPU、较旧和较新的 NVIDIA GPU 版本。您现在可以在 🤗 Optimum 与 Transformers 一起使用此功能,并将其用于 Hugging Face 生态系统中的主要模型。

在 2.0 版本中,PyTorch 将一个原生的缩放点积注意力运算符 (SDPA) 作为 torch.nn.functional 的一部分包含进来。此函数包含可根据输入和使用的硬件应用的几种实现。请参阅 官方文档 以获取更多信息,以及 这篇博文 以了解基准测试。

我们在 🤗 Optimum 中开箱即用地提供了与这些优化的集成,以便您可以转换任何受支持的 🤗 Transformers 模型,以便在相关时使用优化的路径和 scaled_dot_product_attention 函数。

PyTorch 原生的 `scaled_dot_product_attention` 正在逐渐被 [设置为默认并集成到 🤗 Transformers 中](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention)。对于在 Transformers 中支持 SDPA 的模型,我们不推荐使用 BetterTransformer,并建议您直接使用 Transformers 和 PyTorch 的最新版本,通过 SDPA 实现注意力优化(Flash Attention、内存高效注意力)。
PyTorch 原生的 `scaled_dot_product_attention` 运算符只有在没有提供 `attention_mask` 的情况下才能分派到 Flash Attention。

因此,在训练模式下,BetterTransformer 集成默认情况下 **放弃了掩码支持,并且只能用于不需要填充掩码进行批量训练的训练**。例如,掩码语言建模或因果语言建模就是这种情况。BetterTransformer 不适合在需要填充掩码的任务上进行模型微调。

在推理模式下,会保留填充掩码以确保正确性,因此只有在批量大小为 1 的情况下才能预期加速效果。

支持的模型

以下是支持的模型列表

如果您希望支持更多模型,请在 🤗 Optimum 中打开一个问题,或者如果您想自己添加,请查看 贡献指南

快速使用

为了使用 BetterTransformer API,只需运行以下命令

>>> from transformers import AutoModelForSequenceClassification
>>> from optimum.bettertransformer import BetterTransformer
>>> model_hf = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
>>> model = BetterTransformer.transform(model_hf, keep_original_model=True)

如果您想用其 BetterTransformer 版本覆盖当前模型,可以保留 keep_original_model=False

有关 教程 部分的更多详细信息,以深入了解如何使用它,或者查看 Google Colab 演示

< > 更新 在 GitHub 上