Optimum 文档
如何使用 ONNX Runtime 加速训练
并获得增强的文档体验
开始使用
如何使用 ONNX Runtime 加速训练
Optimum 通过 ORTTrainer
API 集成了 ONNX Runtime Training,该 API 扩展了 Transformers 中的 Trainer
。通过此扩展,与 eager 模式下的 PyTorch 相比,许多流行的 Hugging Face 模型的训练时间可以减少 35% 以上。
ORTTrainer
和 ORTSeq2SeqTrainer
API 可以轻松地将 ONNX Runtime (ORT) 与 Trainer
中的其他功能结合使用。它包含功能完整的训练循环和评估循环,并支持超参数搜索、混合精度训练以及使用多个 NVIDIA 和 AMD GPU 的分布式训练。借助 ONNX Runtime 后端,ORTTrainer
和 ORTSeq2SeqTrainer
可以利用:
- 计算图优化:常量折叠、节点消除、节点融合
- 高效的内存规划
- 内核优化
- ORT 融合 Adam 优化器:将应用于模型所有参数的逐元素更新批处理到一个或几个内核启动中
- 更高效的 FP16 优化器:消除大量设备到主机的内存复制
- 混合精度训练
尝试一下,在 🤗 Transformers 中训练模型时,实现更低的延迟、更高的吞吐量和更大的最大批处理大小!
性能
下图显示了使用 Optimum 的 Hugging Face 模型在使用 ONNX Runtime 和 DeepSpeed ZeRO Stage 1 进行训练时,令人印象深刻的加速效果,从 39% 到 130% 不等。性能测量是在选定的 Hugging Face 模型上完成的,PyTorch 作为基线运行,仅 ONNX Runtime 用于训练作为第二次运行,ONNX Runtime + DeepSpeed ZeRO Stage 1 作为最终运行,显示了最大增益。基线 PyTorch 运行使用的优化器是 AdamW 优化器,ORT Training 运行使用融合 Adam 优化器(在 ORTTrainingArguments
中可用)。运行是在具有 8 个 GPU 的单个 Nvidia A100 节点上执行的。

这些运行使用的版本信息如下
PyTorch: 1.14.0.dev20221103+cu116; ORT: 1.14.0.dev20221103001+cu116; DeepSpeed: 0.6.6; HuggingFace: 4.24.0.dev0; Optimum: 1.4.1.dev0; Cuda: 11.6.2
开始设置环境
要使用 ONNX Runtime 进行训练,您需要一台至少配备一个 NVIDIA 或 AMD GPU 的机器。
要使用 ORTTrainer
或 ORTSeq2SeqTrainer
,您需要安装 ONNX Runtime Training 模块和 Optimum。
安装 ONNX Runtime
要设置环境,我们强烈建议您使用 Docker 安装依赖项,以确保版本正确且配置良好。您可以在此处找到包含各种组合的 dockerfiles。
NVIDIA GPU 设置
下面我们以安装 onnxruntime-training 1.14.0
为例
- 如果您想通过 Dockerfile 安装
onnxruntime-training 1.14.0
docker build -f Dockerfile-ort1.14.0-cu116 -t ort/train:1.14.0 .
pip install onnx ninja pip install torch==1.13.1+cu116 torchvision==0.14.1 -f https://download.pytorch.org/whl/cu116/torch_stable.html pip install onnxruntime-training==1.14.0 -f https://download.onnxruntime.ai/onnxruntime_stable_cu116.html pip install torch-ort pip install --upgrade protobuf==3.20.2
并运行安装后配置
python -m torch_ort.configure
AMD GPU 设置
下面我们以安装 onnxruntime-training
nightly 版本为例
- 如果您想通过 Dockerfile 安装
onnxruntime-training
docker build -f Dockerfile-ort-nightly-rocm57 -t ort/train:nightly .
- 如果您想在本地 Python 环境中安装依赖项之外的其他内容。一旦您正确安装了 ROCM 5.7,就可以使用 pip 安装它们。
pip install onnx ninja pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7 pip install pip install --pre onnxruntime-training -f https://download.onnxruntime.ai/onnxruntime_nightly_rocm57.html pip install torch-ort pip install --upgrade protobuf==3.20.2
并运行安装后配置
python -m torch_ort.configure
安装 Optimum
您可以通过 pypi 安装 Optimum
pip install optimum
或者从源代码安装
pip install git+https://github.com/huggingface/optimum.git
此命令安装当前 main 开发版本的 Optimum,其中可能包含最新的开发成果(新功能、错误修复)。但是,main 版本可能不是很稳定。如果您遇到任何问题,请打开一个 issue,以便我们尽快修复。
ORTTrainer
ORTTrainer
类继承了 Transformers 的 Trainer
。您可以通过将 transformers 的 Trainer
替换为 ORTTrainer
来轻松调整代码,从而利用 ONNX Runtime 提供的加速功能。以下是将 ORTTrainer
与 Trainer
进行比较的用法示例
-from transformers import Trainer, TrainingArguments
+from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments
# Step 1: Define training arguments
-training_args = TrainingArguments(
+training_args = ORTTrainingArguments(
output_dir="path/to/save/folder/",
- optim = "adamw_hf",
+ optim="adamw_ort_fused",
...
)
# Step 2: Create your ONNX Runtime Trainer
-trainer = Trainer(
+trainer = ORTTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
+ feature="text-classification",
...
)
# Step 3: Use ONNX Runtime for training!🤗
trainer.train()
在 optimum 仓库中查看更详细的 示例脚本。
ORTSeq2SeqTrainer
ORTSeq2SeqTrainer
类类似于 Transformers 的 Seq2SeqTrainer
。您可以通过将 transformers 的 Seq2SeqTrainer
替换为 ORTSeq2SeqTrainer
来轻松调整代码,从而利用 ONNX Runtime 提供的加速功能。以下是将 ORTSeq2SeqTrainer
与 Seq2SeqTrainer
进行比较的用法示例
-from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
+from optimum.onnxruntime import ORTSeq2SeqTrainer, ORTSeq2SeqTrainingArguments
# Step 1: Define training arguments
-training_args = Seq2SeqTrainingArguments(
+training_args = ORTSeq2SeqTrainingArguments(
output_dir="path/to/save/folder/",
- optim = "adamw_hf",
+ optim="adamw_ort_fused",
...
)
# Step 2: Create your ONNX Runtime Seq2SeqTrainer
-trainer = Seq2SeqTrainer(
+trainer = ORTSeq2SeqTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
+ feature="text2text-generation",
...
)
# Step 3: Use ONNX Runtime for training!🤗
trainer.train()
在 optimum 仓库中查看更详细的 示例脚本。
ORTTrainingArguments
ORTTrainingArguments
类继承了 Transformers 中的 TrainingArguments
类。除了 Transformers 中实现的优化器之外,它还允许您使用 ONNX Runtime 中实现的优化器。将 Seq2SeqTrainingArguments
替换为 ORTSeq2SeqTrainingArguments
-from transformers import TrainingArguments
+from optimum.onnxruntime import ORTTrainingArguments
-training_args = TrainingArguments(
+training_args = ORTTrainingArguments(
output_dir="path/to/save/folder/",
num_train_epochs=1,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
warmup_steps=500,
weight_decay=0.01,
logging_dir="path/to/save/folder/",
- optim = "adamw_hf",
+ optim="adamw_ort_fused", # Fused Adam optimizer implemented by ORT
)
ONNX Runtime 支持 DeepSpeed(目前仅支持 ZeRO stage 1 和 2)。您可以在 Optimum 仓库中找到一些 DeepSpeed 配置示例。
ORTSeq2SeqTrainingArguments
ORTSeq2SeqTrainingArguments
类继承了 Transformers 中的 Seq2SeqTrainingArguments
类。除了 Transformers 中实现的优化器之外,它还允许您使用 ONNX Runtime 中实现的优化器。将 Seq2SeqTrainingArguments
替换为 ORTSeq2SeqTrainingArguments
-from transformers import Seq2SeqTrainingArguments
+from optimum.onnxruntime import ORTSeq2SeqTrainingArguments
-training_args = Seq2SeqTrainingArguments(
+training_args = ORTSeq2SeqTrainingArguments(
output_dir="path/to/save/folder/",
num_train_epochs=1,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
warmup_steps=500,
weight_decay=0.01,
logging_dir="path/to/save/folder/",
- optim = "adamw_hf",
+ optim="adamw_ort_fused", # Fused Adam optimizer implemented by ORT
)
ONNX Runtime 支持 DeepSpeed(目前仅支持 ZeRO stage 1 和 2)。您可以在 Optimum 仓库中找到一些 DeepSpeed 配置示例。
ORTModule+StableDiffusion
Optimum 支持通过 ONNX Runtime 加速 Hugging Face Diffusers,请参阅此示例。启用 ONNX Runtime Training 所需的核心更改总结如下
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
+from onnxruntime.training.ortmodule import ORTModule
+from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
unet = UNet2DConditionModel.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="unet",
...
)
text_encoder = CLIPTextModel.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="text_encoder",
...
)
vae = AutoencoderKL.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="vae",
...
)
optimizer = torch.optim.AdamW(
unet.parameters(),
...
)
+vae = ORTModule(vae)
+text_encoder = ORTModule(text_encoder)
+unet = ORTModule(unet)
+optimizer = ORT_FP16_Optimizer(optimizer)
其他资源
如果您对 ORTTrainer
有任何问题或疑问,请在 Optimum Github 上提交 issue,或在 HuggingFace 社区论坛 上与我们讨论,谢谢!🤗