如何使用 ONNX Runtime 加速训练
Optimum 通过一个 ORTTrainer
API 集成了 ONNX Runtime 训练,该 API 扩展了 Transformers 中的 Trainer
。使用此扩展,与 PyTorch 在急切模式下相比,许多流行的 Hugging Face 模型的训练时间可以减少 35% 以上。
ORTTrainer
和 ORTSeq2SeqTrainer
API 使得将 ONNX Runtime (ORT) 与 Trainer
中的其他功能组合起来变得容易。它包含功能完整的训练循环和评估循环,并支持超参数搜索、混合精度训练以及使用多个 NVIDIA 和 AMD GPU 进行分布式训练。借助 ONNX Runtime 后端,ORTTrainer
和 ORTSeq2SeqTrainer
利用了
- 计算图优化:常量折叠、节点消除、节点融合
- 高效的内存规划
- 内核优化
- ORT 融合的 Adam 优化器:将应用于模型所有参数的逐元素更新批处理到一个或几个内核启动中
- 更有效的 FP16 优化器:消除了大量设备到主机内存的复制
- 混合精度训练
试试看,在 🤗 Transformers 中训练模型时实现 更低的延迟、更高的吞吐量和更大的最大批次大小!
性能
下图显示了 Hugging Face 模型使用 Optimum 时在 使用 ONNX Runtime 和 DeepSpeed ZeRO Stage 1 进行训练时,从 39% 到 130% 的显著加速。性能测量是在选定的 Hugging Face 模型上完成的,其中 PyTorch 作为基准运行,仅 ONNX Runtime 用于训练作为第二次运行,ONNX Runtime + DeepSpeed ZeRO Stage 1 作为最终运行,显示最大增益。基准 PyTorch 运行使用的优化器是 AdamW 优化器,而 ORT 训练运行使用的是 Fused Adam 优化器(在 ORTTrainingArguments
中可用)。这些运行在一个具有 8 个 GPU 的单个 Nvidia A100 节点上执行。
这些运行使用的版本信息如下
PyTorch: 1.14.0.dev20221103+cu116; ORT: 1.14.0.dev20221103001+cu116; DeepSpeed: 0.6.6; HuggingFace: 4.24.0.dev0; Optimum: 1.4.1.dev0; Cuda: 11.6.2
从设置环境开始
要使用 ONNX Runtime 进行训练,您需要一台至少配备一个 NVIDIA 或 AMD GPU 的机器。
要使用 ORTTrainer
或 ORTSeq2SeqTrainer
,您需要安装 ONNX Runtime 训练模块和 Optimum。
安装 ONNX Runtime
要设置环境,我们 强烈建议 您使用 Docker 安装依赖项,以确保版本正确且配置良好。您可以在 此处 找到包含各种组合的 Dockerfile。
NVIDIA GPU 设置
下面我们以安装 onnxruntime-training 1.14.0
为例
- 如果您想通过Dockerfile安装
onnxruntime-training 1.14.0
。
docker build -f Dockerfile-ort1.14.0-cu116 -t ort/train:1.14.0 .
pip install onnx ninja pip install torch==1.13.1+cu116 torchvision==0.14.1 -f https://download.pytorch.org/whl/cu116/torch_stable.html pip install onnxruntime-training==1.14.0 -f https://download.onnxruntime.ai/onnxruntime_stable_cu116.html pip install torch-ort pip install --upgrade protobuf==3.20.2
并运行安装后配置
python -m torch_ort.configure
AMD GPU 设置
下面我们将以安装onnxruntime-training
nightly 版本为例。
- 如果您想通过Dockerfile安装
onnxruntime-training
。
docker build -f Dockerfile-ort-nightly-rocm57 -t ort/train:nightly .
- 如果您想在本地 Python 环境中安装除上述之外的依赖项,您可以在安装好ROCM 5.7 后使用 pip 安装它们。
pip install onnx ninja pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7 pip install pip install --pre onnxruntime-training -f https://download.onnxruntime.ai/onnxruntime_nightly_rocm57.html pip install torch-ort pip install --upgrade protobuf==3.20.2
并运行安装后配置
python -m torch_ort.configure
安装 Optimum
您可以通过 pypi 安装 Optimum
pip install optimum
或从源代码安装
pip install git+https://github.com/huggingface/optimum.git
此命令安装 Optimum 的当前主开发版本,其中可能包含最新的开发成果(新功能、错误修复)。但是,主版本可能不稳定。如果您遇到任何问题,请打开一个问题,以便我们尽快解决。
ORTTrainer
ORTTrainer
类继承了 Transformers 的Trainer
。您可以轻松地通过将 Transformers 的 Trainer
替换为 ORTTrainer
来调整代码,以利用 ONNX Runtime 加速。以下是如何使用 ORTTrainer
与 Trainer
相比的示例。
-from transformers import Trainer, TrainingArguments
+from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments
# Step 1: Define training arguments
-training_args = TrainingArguments(
+training_args = ORTTrainingArguments(
output_dir="path/to/save/folder/",
- optim = "adamw_hf",
+ optim="adamw_ort_fused",
...
)
# Step 2: Create your ONNX Runtime Trainer
-trainer = Trainer(
+trainer = ORTTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
+ feature="text-classification",
...
)
# Step 3: Use ONNX Runtime for training!🤗
trainer.train()
在 optimum 存储库中查看更多详细的示例脚本。
ORTSeq2SeqTrainer
ORTSeq2SeqTrainer
类类似于 Transformers 的Seq2SeqTrainer
。您可以轻松地通过将 Transformers 的 Seq2SeqTrainer
替换为 ORTSeq2SeqTrainer
来调整代码,以利用 ONNX Runtime 加速。以下是如何使用 ORTSeq2SeqTrainer
与 Seq2SeqTrainer
相比的示例。
-from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
+from optimum.onnxruntime import ORTSeq2SeqTrainer, ORTSeq2SeqTrainingArguments
# Step 1: Define training arguments
-training_args = Seq2SeqTrainingArguments(
+training_args = ORTSeq2SeqTrainingArguments(
output_dir="path/to/save/folder/",
- optim = "adamw_hf",
+ optim="adamw_ort_fused",
...
)
# Step 2: Create your ONNX Runtime Seq2SeqTrainer
-trainer = Seq2SeqTrainer(
+trainer = ORTSeq2SeqTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
+ feature="text2text-generation",
...
)
# Step 3: Use ONNX Runtime for training!🤗
trainer.train()
在 optimum 存储库中查看更多详细的示例脚本。
ORTTrainingArguments
ORTTrainingArguments
类继承了 Transformers 中的TrainingArguments
类。除了 Transformers 中实现的优化器之外,它还允许您使用 ONNX Runtime 中实现的优化器。将 Seq2SeqTrainingArguments
替换为 ORTSeq2SeqTrainingArguments
。
-from transformers import TrainingArguments
+from optimum.onnxruntime import ORTTrainingArguments
-training_args = TrainingArguments(
+training_args = ORTTrainingArguments(
output_dir="path/to/save/folder/",
num_train_epochs=1,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
warmup_steps=500,
weight_decay=0.01,
logging_dir="path/to/save/folder/",
- optim = "adamw_hf",
+ optim="adamw_ort_fused", # Fused Adam optimizer implemented by ORT
)
ONNX Runtime 支持 DeepSpeed(目前仅支持 ZeRO 阶段 1 和 2)。您可以在 Optimum 存储库中找到一些DeepSpeed 配置示例。
ORTSeq2SeqTrainingArguments
ORTSeq2SeqTrainingArguments
类继承了 Transformers 中的Seq2SeqTrainingArguments
类。除了 Transformers 中实现的优化器之外,它还允许您使用 ONNX Runtime 中实现的优化器。将 Seq2SeqTrainingArguments
替换为 ORTSeq2SeqTrainingArguments
。
-from transformers import Seq2SeqTrainingArguments
+from optimum.onnxruntime import ORTSeq2SeqTrainingArguments
-training_args = Seq2SeqTrainingArguments(
+training_args = ORTSeq2SeqTrainingArguments(
output_dir="path/to/save/folder/",
num_train_epochs=1,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
warmup_steps=500,
weight_decay=0.01,
logging_dir="path/to/save/folder/",
- optim = "adamw_hf",
+ optim="adamw_ort_fused", # Fused Adam optimizer implemented by ORT
)
ONNX Runtime 支持 DeepSpeed(目前仅支持 ZeRO 阶段 1 和 2)。您可以在 Optimum 存储库中找到一些DeepSpeed 配置示例。
ORTModule+StableDiffusion
Optimum 支持使用 ONNX Runtime 在此示例中加速 Hugging Face Diffusers。启用 ONNX Runtime 训练所需的核心更改总结如下
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
+from onnxruntime.training.ortmodule import ORTModule
+from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
unet = UNet2DConditionModel.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="unet",
...
)
text_encoder = CLIPTextModel.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="text_encoder",
...
)
vae = AutoencoderKL.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="vae",
...
)
optimizer = torch.optim.AdamW(
unet.parameters(),
...
)
+vae = ORTModule(vae)
+text_encoder = ORTModule(text_encoder)
+unet = ORTModule(unet)
+optimizer = ORT_FP16_Optimizer(optimizer)
其他资源
如果您在使用 ORTTrainer
时遇到任何问题或疑问,请在Optimum Github 上提交问题,或在HuggingFace 社区论坛 上与我们讨论,🤗 感谢您的支持!