Optimum 文档

如何使用 ONNX Runtime 加速训练

您正在查看 main 版本,该版本需要从源代码安装。如果您想要常规的 pip 安装,请查看最新的稳定版本 (v1.24.0)。
Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

如何使用 ONNX Runtime 加速训练

Optimum 通过 ORTTrainer API 集成了 ONNX Runtime Training,该 API 扩展了 Transformers 中的 Trainer。通过此扩展,与 eager 模式下的 PyTorch 相比,许多流行的 Hugging Face 模型的训练时间可以减少 35% 以上。

ORTTrainerORTSeq2SeqTrainer API 可以轻松地将 ONNX Runtime (ORT)Trainer 中的其他功能结合使用。它包含功能完整的训练循环和评估循环,并支持超参数搜索、混合精度训练以及使用多个 NVIDIAAMD GPU 的分布式训练。借助 ONNX Runtime 后端,ORTTrainerORTSeq2SeqTrainer 可以利用:

  • 计算图优化:常量折叠、节点消除、节点融合
  • 高效的内存规划
  • 内核优化
  • ORT 融合 Adam 优化器:将应用于模型所有参数的逐元素更新批处理到一个或几个内核启动中
  • 更高效的 FP16 优化器:消除大量设备到主机的内存复制
  • 混合精度训练

尝试一下,在 🤗 Transformers 中训练模型时,实现更低的延迟、更高的吞吐量和更大的最大批处理大小

性能

下图显示了使用 Optimum 的 Hugging Face 模型在使用 ONNX Runtime 和 DeepSpeed ZeRO Stage 1 进行训练时,令人印象深刻的加速效果,从 39% 到 130% 不等。性能测量是在选定的 Hugging Face 模型上完成的,PyTorch 作为基线运行,仅 ONNX Runtime 用于训练作为第二次运行,ONNX Runtime + DeepSpeed ZeRO Stage 1 作为最终运行,显示了最大增益。基线 PyTorch 运行使用的优化器是 AdamW 优化器,ORT Training 运行使用融合 Adam 优化器(在 ORTTrainingArguments 中可用)。运行是在具有 8 个 GPU 的单个 Nvidia A100 节点上执行的。

ONNX Runtime Training Benchmark

这些运行使用的版本信息如下

PyTorch: 1.14.0.dev20221103+cu116; ORT: 1.14.0.dev20221103001+cu116; DeepSpeed: 0.6.6; HuggingFace: 4.24.0.dev0; Optimum: 1.4.1.dev0; Cuda: 11.6.2

开始设置环境

要使用 ONNX Runtime 进行训练,您需要一台至少配备一个 NVIDIA 或 AMD GPU 的机器。

要使用 ORTTrainerORTSeq2SeqTrainer,您需要安装 ONNX Runtime Training 模块和 Optimum。

安装 ONNX Runtime

要设置环境,我们强烈建议您使用 Docker 安装依赖项,以确保版本正确且配置良好。您可以在此处找到包含各种组合的 dockerfiles。

NVIDIA GPU 设置

下面我们以安装 onnxruntime-training 1.14.0 为例

  • 如果您想通过 Dockerfile 安装 onnxruntime-training 1.14.0
docker build -f Dockerfile-ort1.14.0-cu116 -t ort/train:1.14.0 .
  • 如果您想在本地 Python 环境中安装依赖项之外的其他内容。一旦您正确安装了 CUDA 11.6cuDNN 8,就可以使用 pip 安装它们。
pip install onnx ninja
pip install torch==1.13.1+cu116 torchvision==0.14.1 -f https://download.pytorch.org/whl/cu116/torch_stable.html
pip install onnxruntime-training==1.14.0 -f https://download.onnxruntime.ai/onnxruntime_stable_cu116.html
pip install torch-ort
pip install --upgrade protobuf==3.20.2

并运行安装后配置

python -m torch_ort.configure

AMD GPU 设置

下面我们以安装 onnxruntime-training nightly 版本为例

  • 如果您想通过 Dockerfile 安装 onnxruntime-training
docker build -f Dockerfile-ort-nightly-rocm57 -t ort/train:nightly .
  • 如果您想在本地 Python 环境中安装依赖项之外的其他内容。一旦您正确安装了 ROCM 5.7,就可以使用 pip 安装它们。
pip install onnx ninja
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7
pip install pip install --pre onnxruntime-training -f https://download.onnxruntime.ai/onnxruntime_nightly_rocm57.html
pip install torch-ort
pip install --upgrade protobuf==3.20.2

并运行安装后配置

python -m torch_ort.configure

安装 Optimum

您可以通过 pypi 安装 Optimum

pip install optimum

或者从源代码安装

pip install git+https://github.com/huggingface/optimum.git

此命令安装当前 main 开发版本的 Optimum,其中可能包含最新的开发成果(新功能、错误修复)。但是,main 版本可能不是很稳定。如果您遇到任何问题,请打开一个 issue,以便我们尽快修复。

ORTTrainer

ORTTrainer 类继承了 Transformers 的 Trainer。您可以通过将 transformers 的 Trainer 替换为 ORTTrainer 来轻松调整代码,从而利用 ONNX Runtime 提供的加速功能。以下是将 ORTTrainerTrainer 进行比较的用法示例

-from transformers import Trainer, TrainingArguments
+from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments

# Step 1: Define training arguments
-training_args = TrainingArguments(
+training_args = ORTTrainingArguments(
    output_dir="path/to/save/folder/",
-   optim = "adamw_hf",
+   optim="adamw_ort_fused",
    ...
)

# Step 2: Create your ONNX Runtime Trainer
-trainer = Trainer(
+trainer = ORTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
+   feature="text-classification",
    ...
)

# Step 3: Use ONNX Runtime for training!🤗
trainer.train()

在 optimum 仓库中查看更详细的 示例脚本

ORTSeq2SeqTrainer

ORTSeq2SeqTrainer 类类似于 Transformers 的 Seq2SeqTrainer。您可以通过将 transformers 的 Seq2SeqTrainer 替换为 ORTSeq2SeqTrainer 来轻松调整代码,从而利用 ONNX Runtime 提供的加速功能。以下是将 ORTSeq2SeqTrainerSeq2SeqTrainer 进行比较的用法示例

-from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
+from optimum.onnxruntime import ORTSeq2SeqTrainer, ORTSeq2SeqTrainingArguments

# Step 1: Define training arguments
-training_args = Seq2SeqTrainingArguments(
+training_args = ORTSeq2SeqTrainingArguments(
    output_dir="path/to/save/folder/",
-   optim = "adamw_hf",
+   optim="adamw_ort_fused",
    ...
)

# Step 2: Create your ONNX Runtime Seq2SeqTrainer
-trainer = Seq2SeqTrainer(
+trainer = ORTSeq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
+   feature="text2text-generation",
    ...
)

# Step 3: Use ONNX Runtime for training!🤗
trainer.train()

在 optimum 仓库中查看更详细的 示例脚本

ORTTrainingArguments

ORTTrainingArguments 类继承了 Transformers 中的 TrainingArguments 类。除了 Transformers 中实现的优化器之外,它还允许您使用 ONNX Runtime 中实现的优化器。将 Seq2SeqTrainingArguments 替换为 ORTSeq2SeqTrainingArguments

-from transformers import TrainingArguments
+from optimum.onnxruntime import ORTTrainingArguments

-training_args = TrainingArguments(
+training_args =  ORTTrainingArguments(
    output_dir="path/to/save/folder/",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="path/to/save/folder/",
-   optim = "adamw_hf",
+   optim="adamw_ort_fused",  # Fused Adam optimizer implemented by ORT
)

ONNX Runtime 支持 DeepSpeed(目前仅支持 ZeRO stage 1 和 2)。您可以在 Optimum 仓库中找到一些 DeepSpeed 配置示例

ORTSeq2SeqTrainingArguments

ORTSeq2SeqTrainingArguments 类继承了 Transformers 中的 Seq2SeqTrainingArguments 类。除了 Transformers 中实现的优化器之外,它还允许您使用 ONNX Runtime 中实现的优化器。将 Seq2SeqTrainingArguments 替换为 ORTSeq2SeqTrainingArguments

-from transformers import Seq2SeqTrainingArguments
+from optimum.onnxruntime import ORTSeq2SeqTrainingArguments

-training_args = Seq2SeqTrainingArguments(
+training_args =  ORTSeq2SeqTrainingArguments(
    output_dir="path/to/save/folder/",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="path/to/save/folder/",
-   optim = "adamw_hf",
+   optim="adamw_ort_fused",  # Fused Adam optimizer implemented by ORT
)

ONNX Runtime 支持 DeepSpeed(目前仅支持 ZeRO stage 1 和 2)。您可以在 Optimum 仓库中找到一些 DeepSpeed 配置示例

ORTModule+StableDiffusion

Optimum 支持通过 ONNX Runtime 加速 Hugging Face Diffusers,请参阅此示例。启用 ONNX Runtime Training 所需的核心更改总结如下

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel

+from onnxruntime.training.ortmodule import ORTModule
+from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", 
    subfolder="unet",
    ...
)
text_encoder = CLIPTextModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", 
    subfolder="text_encoder",
    ...
)
vae = AutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4", 
    subfolder="vae",
    ...
)

optimizer = torch.optim.AdamW(
    unet.parameters(),
    ...
)

+vae = ORTModule(vae)
+text_encoder = ORTModule(text_encoder)
+unet = ORTModule(unet)

+optimizer = ORT_FP16_Optimizer(optimizer)

其他资源

如果您对 ORTTrainer 有任何问题或疑问,请在 Optimum Github 上提交 issue,或在 HuggingFace 社区论坛 上与我们讨论,谢谢!🤗

< > 在 GitHub 上更新