Optimum 文档
如何使用 ONNX Runtime 加速训练
并获得增强的文档体验
开始使用
如何使用 ONNX Runtime 加速训练
Optimum 通过 ORTTrainer
API 集成了 ONNX Runtime 训练,该 API 扩展了 Transformers 中的 Trainer
。通过此扩展,与 Eager 模式下的 PyTorch 相比,许多流行的 Hugging Face 模型的训练时间可以减少 35% 以上。
ORTTrainer
和 ORTSeq2SeqTrainer
API 使得将 ONNX Runtime (ORT) 与 Trainer
中的其他功能轻松组合。它包含功能完整的训练循环和评估循环,并支持超参数搜索、混合精度训练以及使用多个 NVIDIA 和 AMD GPU 的分布式训练。借助 ONNX Runtime 后端,ORTTrainer
和 ORTSeq2SeqTrainer
可利用以下优势:
- 计算图优化:常量折叠、节点消除、节点融合
- 高效内存规划
- 内核优化
- ORT 融合 Adam 优化器:将应用于所有模型参数的逐元素更新批处理为一个或几个内核启动
- 更高效的 FP16 优化器:消除了大量的设备到主机内存复制
- 混合精度训练
尝试一下,在 🤗 Transformers 中训练模型时实现更低延迟、更高吞吐量和更大的最大批大小!
性能
下表显示了当使用 ONNX Runtime 和 DeepSpeed ZeRO Stage 1 进行训练时,Optimum 对 Hugging Face 模型实现了令人印象深刻的加速,从 39% 到 130%。性能测量是在选定的 Hugging Face 模型上进行的,其中 PyTorch 作为基线运行,仅 ONNX Runtime 进行训练作为第二次运行,ONNX Runtime + DeepSpeed ZeRO Stage 1 作为最终运行,显示出最大收益。基线 PyTorch 运行使用的优化器是 AdamW 优化器,而 ORT 训练运行使用融合 Adam 优化器(在 ORTTrainingArguments
中可用)。运行在一台具有 8 个 GPU 的 Nvidia A100 节点上执行。

这些运行中使用的版本信息如下:
PyTorch: 1.14.0.dev20221103+cu116; ORT: 1.14.0.dev20221103001+cu116; DeepSpeed: 0.6.6; HuggingFace: 4.24.0.dev0; Optimum: 1.4.1.dev0; Cuda: 11.6.2
开始设置环境
要使用 ONNX Runtime 进行训练,您需要一台至少配备一个 NVIDIA 或 AMD GPU 的机器。
要使用 ORTTrainer
或 ORTSeq2SeqTrainer
,您需要安装 ONNX Runtime Training 模块和 Optimum。
安装 ONNX Runtime
为了设置环境,我们**强烈建议**您使用 Docker 安装依赖项,以确保版本正确且配置良好。您可以在此处找到各种组合的 Dockerfile。
NVIDIA GPU 的设置
以下我们以安装 onnxruntime-training 1.14.0
为例
- 如果您想通过 Dockerfile 安装
onnxruntime-training 1.14.0
docker build -f Dockerfile-ort1.14.0-cu116 -t ort/train:1.14.0 .
pip install onnx ninja pip install torch==1.13.1+cu116 torchvision==0.14.1 -f https://download.pytorch.org/whl/cu116/torch_stable.html pip install onnxruntime-training==1.14.0 -f https://download.onnxruntime.ai/onnxruntime_stable_cu116.html pip install torch-ort pip install --upgrade protobuf==3.20.2
并运行安装后配置
python -m torch_ort.configure
AMD GPU 的设置
以下我们以安装 onnxruntime-training
nightly 为例
- 如果您想通过 Dockerfile 安装
onnxruntime-training
docker build -f Dockerfile-ort-nightly-rocm57 -t ort/train:nightly .
- 如果您想在本地 Python 环境中安装其他依赖项。您可以在成功安装 ROCM 5.7 后,通过 pip 安装它们。
pip install onnx ninja pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7 pip install pip install --pre onnxruntime-training -f https://download.onnxruntime.ai/onnxruntime_nightly_rocm57.html pip install torch-ort pip install --upgrade protobuf==3.20.2
并运行安装后配置
python -m torch_ort.configure
安装 Optimum
您可以通过 pypi 安装 Optimum
pip install optimum
或从源代码安装
pip install git+https://github.com/huggingface/optimum.git
此命令安装 Optimum 的当前主开发版本,其中可能包含最新的开发(新功能、错误修复)。但是,主版本可能不太稳定。如果遇到任何问题,请提出 问题,以便我们尽快解决。
ORTTrainer
ORTTrainer
类继承了 Transformers 的 Trainer
。您可以通过用 ORTTrainer
替换 transformers 的 Trainer
来轻松调整代码,以利用 ONNX Runtime 带来的加速。以下是与 Trainer
相比如何使用 ORTTrainer
的示例:
-from transformers import Trainer, TrainingArguments
+from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments
# Step 1: Define training arguments
-training_args = TrainingArguments(
+training_args = ORTTrainingArguments(
output_dir="path/to/save/folder/",
- optim = "adamw_hf",
+ optim="adamw_ort_fused",
...
)
# Step 2: Create your ONNX Runtime Trainer
-trainer = Trainer(
+trainer = ORTTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
+ feature="text-classification",
...
)
# Step 3: Use ONNX Runtime for training!🤗
trainer.train()
请查看 Optimum 仓库中更详细的示例脚本。
ORTSeq2SeqTrainer
ORTSeq2SeqTrainer
类类似于 Transformers 的 Seq2SeqTrainer
。您可以通过用 ORTSeq2SeqTrainer
替换 transformers 的 Seq2SeqTrainer
来轻松调整代码,以利用 ONNX Runtime 带来的加速。以下是与 Seq2SeqTrainer
相比如何使用 ORTSeq2SeqTrainer
的示例
-from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
+from optimum.onnxruntime import ORTSeq2SeqTrainer, ORTSeq2SeqTrainingArguments
# Step 1: Define training arguments
-training_args = Seq2SeqTrainingArguments(
+training_args = ORTSeq2SeqTrainingArguments(
output_dir="path/to/save/folder/",
- optim = "adamw_hf",
+ optim="adamw_ort_fused",
...
)
# Step 2: Create your ONNX Runtime Seq2SeqTrainer
-trainer = Seq2SeqTrainer(
+trainer = ORTSeq2SeqTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
+ feature="text2text-generation",
...
)
# Step 3: Use ONNX Runtime for training!🤗
trainer.train()
请查看 Optimum 仓库中更详细的示例脚本。
ORTTrainingArguments
ORTTrainingArguments
类继承了 Transformers 中的 TrainingArguments
类。除了 Transformers 中实现的优化器外,它还允许您使用 ONNX Runtime 中实现的优化器。将 Seq2SeqTrainingArguments
替换为 ORTSeq2SeqTrainingArguments
。
-from transformers import TrainingArguments
+from optimum.onnxruntime import ORTTrainingArguments
-training_args = TrainingArguments(
+training_args = ORTTrainingArguments(
output_dir="path/to/save/folder/",
num_train_epochs=1,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
warmup_steps=500,
weight_decay=0.01,
logging_dir="path/to/save/folder/",
- optim = "adamw_hf",
+ optim="adamw_ort_fused", # Fused Adam optimizer implemented by ORT
)
ONNX Runtime 支持 DeepSpeed(目前仅支持 ZeRO 阶段 1 和 2)。您可以在 Optimum 存储库中找到一些 DeepSpeed 配置示例。
ORTSeq2SeqTrainingArguments
ORTSeq2SeqTrainingArguments
类继承了 Transformers 中的 Seq2SeqTrainingArguments
类。除了 Transformers 中实现的优化器外,它还允许您使用 ONNX Runtime 中实现的优化器。将 Seq2SeqTrainingArguments
替换为 ORTSeq2SeqTrainingArguments
。
-from transformers import Seq2SeqTrainingArguments
+from optimum.onnxruntime import ORTSeq2SeqTrainingArguments
-training_args = Seq2SeqTrainingArguments(
+training_args = ORTSeq2SeqTrainingArguments(
output_dir="path/to/save/folder/",
num_train_epochs=1,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
warmup_steps=500,
weight_decay=0.01,
logging_dir="path/to/save/folder/",
- optim = "adamw_hf",
+ optim="adamw_ort_fused", # Fused Adam optimizer implemented by ORT
)
ONNX Runtime 支持 DeepSpeed(目前仅支持 ZeRO 阶段 1 和 2)。您可以在 Optimum 存储库中找到一些 DeepSpeed 配置示例。
ORTModule+StableDiffusion
Optimum 支持在此示例中通过 ONNX Runtime 加速 Hugging Face Diffusers。启用 ONNX Runtime 训练所需的核心更改总结如下:
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
+from onnxruntime.training.ortmodule import ORTModule
+from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
unet = UNet2DConditionModel.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="unet",
...
)
text_encoder = CLIPTextModel.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="text_encoder",
...
)
vae = AutoencoderKL.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="vae",
...
)
optimizer = torch.optim.AdamW(
unet.parameters(),
...
)
+vae = ORTModule(vae)
+text_encoder = ORTModule(text_encoder)
+unet = ORTModule(unet)
+optimizer = ORT_FP16_Optimizer(optimizer)
其他资源
如果您对 ORTTrainer
有任何问题,请在 Optimum Github 上提出问题或在 HuggingFace 的社区论坛上与我们讨论,祝好 🤗!