Optimum 文档

如何使用 ONNX Runtime 加速训练

您正在查看 主分支 版本,需要从源代码安装. 如果您想使用常规的 pip 安装,请查看最新的稳定版本 (v1.23.1).
Hugging Face's logo
加入 Hugging Face 社区

并获得增强文档体验

开始使用

如何使用 ONNX Runtime 加速训练

Optimum 通过一个 ORTTrainer API 集成了 ONNX Runtime 训练,该 API 扩展了 Transformers 中的 Trainer。使用此扩展,与 PyTorch 在急切模式下相比,许多流行的 Hugging Face 模型的训练时间可以减少 35% 以上。

ORTTrainerORTSeq2SeqTrainer API 使得将 ONNX Runtime (ORT)Trainer 中的其他功能组合起来变得容易。它包含功能完整的训练循环和评估循环,并支持超参数搜索、混合精度训练以及使用多个 NVIDIAAMD GPU 进行分布式训练。借助 ONNX Runtime 后端,ORTTrainerORTSeq2SeqTrainer 利用了

  • 计算图优化:常量折叠、节点消除、节点融合
  • 高效的内存规划
  • 内核优化
  • ORT 融合的 Adam 优化器:将应用于模型所有参数的逐元素更新批处理到一个或几个内核启动中
  • 更有效的 FP16 优化器:消除了大量设备到主机内存的复制
  • 混合精度训练

试试看,在 🤗 Transformers 中训练模型时实现 更低的延迟、更高的吞吐量和更大的最大批次大小

性能

下图显示了 Hugging Face 模型使用 Optimum 时在 使用 ONNX Runtime 和 DeepSpeed ZeRO Stage 1 进行训练时,从 39% 到 130% 的显著加速。性能测量是在选定的 Hugging Face 模型上完成的,其中 PyTorch 作为基准运行,仅 ONNX Runtime 用于训练作为第二次运行,ONNX Runtime + DeepSpeed ZeRO Stage 1 作为最终运行,显示最大增益。基准 PyTorch 运行使用的优化器是 AdamW 优化器,而 ORT 训练运行使用的是 Fused Adam 优化器(在 ORTTrainingArguments 中可用)。这些运行在一个具有 8 个 GPU 的单个 Nvidia A100 节点上执行。

ONNX Runtime Training Benchmark

这些运行使用的版本信息如下

PyTorch: 1.14.0.dev20221103+cu116; ORT: 1.14.0.dev20221103001+cu116; DeepSpeed: 0.6.6; HuggingFace: 4.24.0.dev0; Optimum: 1.4.1.dev0; Cuda: 11.6.2

从设置环境开始

要使用 ONNX Runtime 进行训练,您需要一台至少配备一个 NVIDIA 或 AMD GPU 的机器。

要使用 ORTTrainerORTSeq2SeqTrainer,您需要安装 ONNX Runtime 训练模块和 Optimum。

安装 ONNX Runtime

要设置环境,我们 强烈建议 您使用 Docker 安装依赖项,以确保版本正确且配置良好。您可以在 此处 找到包含各种组合的 Dockerfile。

NVIDIA GPU 设置

下面我们以安装 onnxruntime-training 1.14.0 为例

  • 如果您想通过Dockerfile安装onnxruntime-training 1.14.0
docker build -f Dockerfile-ort1.14.0-cu116 -t ort/train:1.14.0 .
  • 如果您想在本地 Python 环境中安装除上述之外的依赖项,您可以在安装好CUDA 11.6cuDNN 8 后使用 pip 安装它们。
pip install onnx ninja
pip install torch==1.13.1+cu116 torchvision==0.14.1 -f https://download.pytorch.org/whl/cu116/torch_stable.html
pip install onnxruntime-training==1.14.0 -f https://download.onnxruntime.ai/onnxruntime_stable_cu116.html
pip install torch-ort
pip install --upgrade protobuf==3.20.2

并运行安装后配置

python -m torch_ort.configure

AMD GPU 设置

下面我们将以安装onnxruntime-training nightly 版本为例。

  • 如果您想通过Dockerfile安装onnxruntime-training
docker build -f Dockerfile-ort-nightly-rocm57 -t ort/train:nightly .
  • 如果您想在本地 Python 环境中安装除上述之外的依赖项,您可以在安装好ROCM 5.7 后使用 pip 安装它们。
pip install onnx ninja
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7
pip install pip install --pre onnxruntime-training -f https://download.onnxruntime.ai/onnxruntime_nightly_rocm57.html
pip install torch-ort
pip install --upgrade protobuf==3.20.2

并运行安装后配置

python -m torch_ort.configure

安装 Optimum

您可以通过 pypi 安装 Optimum

pip install optimum

或从源代码安装

pip install git+https://github.com/huggingface/optimum.git

此命令安装 Optimum 的当前主开发版本,其中可能包含最新的开发成果(新功能、错误修复)。但是,主版本可能不稳定。如果您遇到任何问题,请打开一个问题,以便我们尽快解决。

ORTTrainer

ORTTrainer 类继承了 Transformers 的Trainer。您可以轻松地通过将 Transformers 的 Trainer 替换为 ORTTrainer 来调整代码,以利用 ONNX Runtime 加速。以下是如何使用 ORTTrainerTrainer 相比的示例。

-from transformers import Trainer, TrainingArguments
+from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments

# Step 1: Define training arguments
-training_args = TrainingArguments(
+training_args = ORTTrainingArguments(
    output_dir="path/to/save/folder/",
-   optim = "adamw_hf",
+   optim="adamw_ort_fused",
    ...
)

# Step 2: Create your ONNX Runtime Trainer
-trainer = Trainer(
+trainer = ORTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
+   feature="text-classification",
    ...
)

# Step 3: Use ONNX Runtime for training!🤗
trainer.train()

在 optimum 存储库中查看更多详细的示例脚本

ORTSeq2SeqTrainer

ORTSeq2SeqTrainer 类类似于 Transformers 的Seq2SeqTrainer。您可以轻松地通过将 Transformers 的 Seq2SeqTrainer 替换为 ORTSeq2SeqTrainer 来调整代码,以利用 ONNX Runtime 加速。以下是如何使用 ORTSeq2SeqTrainerSeq2SeqTrainer 相比的示例。

-from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
+from optimum.onnxruntime import ORTSeq2SeqTrainer, ORTSeq2SeqTrainingArguments

# Step 1: Define training arguments
-training_args = Seq2SeqTrainingArguments(
+training_args = ORTSeq2SeqTrainingArguments(
    output_dir="path/to/save/folder/",
-   optim = "adamw_hf",
+   optim="adamw_ort_fused",
    ...
)

# Step 2: Create your ONNX Runtime Seq2SeqTrainer
-trainer = Seq2SeqTrainer(
+trainer = ORTSeq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
+   feature="text2text-generation",
    ...
)

# Step 3: Use ONNX Runtime for training!🤗
trainer.train()

在 optimum 存储库中查看更多详细的示例脚本

ORTTrainingArguments

ORTTrainingArguments 类继承了 Transformers 中的TrainingArguments 类。除了 Transformers 中实现的优化器之外,它还允许您使用 ONNX Runtime 中实现的优化器。将 Seq2SeqTrainingArguments 替换为 ORTSeq2SeqTrainingArguments

-from transformers import TrainingArguments
+from optimum.onnxruntime import ORTTrainingArguments

-training_args = TrainingArguments(
+training_args =  ORTTrainingArguments(
    output_dir="path/to/save/folder/",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="path/to/save/folder/",
-   optim = "adamw_hf",
+   optim="adamw_ort_fused",  # Fused Adam optimizer implemented by ORT
)

ONNX Runtime 支持 DeepSpeed(目前仅支持 ZeRO 阶段 1 和 2)。您可以在 Optimum 存储库中找到一些DeepSpeed 配置示例

ORTSeq2SeqTrainingArguments

ORTSeq2SeqTrainingArguments 类继承了 Transformers 中的Seq2SeqTrainingArguments 类。除了 Transformers 中实现的优化器之外,它还允许您使用 ONNX Runtime 中实现的优化器。将 Seq2SeqTrainingArguments 替换为 ORTSeq2SeqTrainingArguments

-from transformers import Seq2SeqTrainingArguments
+from optimum.onnxruntime import ORTSeq2SeqTrainingArguments

-training_args = Seq2SeqTrainingArguments(
+training_args =  ORTSeq2SeqTrainingArguments(
    output_dir="path/to/save/folder/",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="path/to/save/folder/",
-   optim = "adamw_hf",
+   optim="adamw_ort_fused",  # Fused Adam optimizer implemented by ORT
)

ONNX Runtime 支持 DeepSpeed(目前仅支持 ZeRO 阶段 1 和 2)。您可以在 Optimum 存储库中找到一些DeepSpeed 配置示例

ORTModule+StableDiffusion

Optimum 支持使用 ONNX Runtime 在此示例中加速 Hugging Face Diffusers。启用 ONNX Runtime 训练所需的核心更改总结如下

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel

+from onnxruntime.training.ortmodule import ORTModule
+from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", 
    subfolder="unet",
    ...
)
text_encoder = CLIPTextModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", 
    subfolder="text_encoder",
    ...
)
vae = AutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4", 
    subfolder="vae",
    ...
)

optimizer = torch.optim.AdamW(
    unet.parameters(),
    ...
)

+vae = ORTModule(vae)
+text_encoder = ORTModule(text_encoder)
+unet = ORTModule(unet)

+optimizer = ORT_FP16_Optimizer(optimizer)

其他资源

如果您在使用 ORTTrainer 时遇到任何问题或疑问,请在Optimum Github 上提交问题,或在HuggingFace 社区论坛 上与我们讨论,🤗 感谢您的支持!

< > 更新 在 GitHub 上