Optimum 文档

基于 ROCm 的 AMD GPU 加速推理

您正在查看 主分支 版本,需要从源代码安装。如果您想使用常规的 pip 安装,请查看最新的稳定版本 (v1.23.1).
Hugging Face's logo
加入 Hugging Face 社区

并获得增强型文档体验

开始使用

基于 ROCm 的 AMD GPU 加速推理

默认情况下,ONNX Runtime 在 CPU 设备上运行推理。但是,可以将支持的操作放在 AMD Instinct GPU 上,同时将任何不支持的操作留在 CPU 上。在大多数情况下,这允许将代价高昂的操作放在 GPU 上并显着加速推理。

我们的测试涉及 AMD Instinct GPU,有关特定 GPU 的兼容性,请参阅此处提供的官方 GPU 支持列表 此处

本指南将向您展示如何在 ONNX Runtime 支持的用于 AMD GPU 的 ROCMExecutionProvider 执行提供程序上运行推理。

安装

以下设置使用 ROCm 6.0 安装带有 ROCm 执行提供程序的 ONNX Runtime 支持。

1 ROCm 安装

请参考 ROCm 安装指南 安装 ROCm 6.0。

2 安装 onnxruntime-rocm

请使用提供的 Dockerfile 示例或从源代码进行本地安装,因为目前无法使用 pip wheel。

Docker 安装

docker build -f Dockerfile -t ort/rocm .

本地安装步骤

2.1 支持 ROCm 的 PyTorch

Optimum ONNX Runtime 集成依赖于 Transformers 的一些功能,这些功能需要 PyTorch。目前,我们建议使用针对 RoCm 6.0 编译的 Pytorch,可以按照 PyTorch 安装指南 安装。

pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0
# Use 'rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2' as the preferred base image when using Docker for PyTorch installation.
2.2 使用 ROCm 执行提供程序的 ONNX Runtime
# pre-requisites
pip install -U pip
pip install cmake onnx
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

# Install ONNXRuntime from source
git clone --single-branch --branch main --recursive https://github.com/Microsoft/onnxruntime onnxruntime
cd onnxruntime

./build.sh --config Release --build_wheel --allow_running_as_root --update --build --parallel --cmake_extra_defines CMAKE_HIP_ARCHITECTURES=gfx90a,gfx942 ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --use_rocm --rocm_home=/opt/rocm
pip install build/Linux/Release/dist/*

注意:这些说明为 MI210/MI250/MI300 gpu 构建 ORT。要支持其他架构,请在构建命令中更新 CMAKE_HIP_ARCHITECTURES

为了避免 `onnxruntime` 和 `onnxruntime-rocm` 之间的冲突,在安装 `onnxruntime-rocm` 之前,请确保未安装包 `onnxruntime`,方法是运行 `pip uninstall onnxruntime`。

检查 ROCm 安装是否成功

在继续之前,运行以下示例代码以检查安装是否成功。

>>> from optimum.onnxruntime import ORTModelForSequenceClassification
>>> from transformers import AutoTokenizer

>>> ort_model = ORTModelForSequenceClassification.from_pretrained(
...   "philschmid/tiny-bert-sst2-distilled",
...   export=True,
...   provider="ROCMExecutionProvider",
... )

>>> tokenizer = AutoTokenizer.from_pretrained("philschmid/tiny-bert-sst2-distilled")
>>> inputs = tokenizer("expectations were low, actual enjoyment was high", return_tensors="pt", padding=True)

>>> outputs = ort_model(**inputs)
>>> assert ort_model.providers == ["ROCMExecutionProvider", "CPUExecutionProvider"]

如果此代码顺利运行,恭喜,安装成功!如果您遇到以下错误或类似错误,

ValueError: Asked to use ROCMExecutionProvider as an ONNX Runtime execution provider, but the available execution providers are ['CPUExecutionProvider'].

则ROCm或ONNX Runtime安装存在问题。

使用 ROCm 执行提供程序与 ORT 模型

对于 ORT 模型,使用非常简单。只需在 ORTModel.from_pretrained() 方法中指定 provider 参数即可。以下是一个示例

>>> from optimum.onnxruntime import ORTModelForSequenceClassification

>>> ort_model = ORTModelForSequenceClassification.from_pretrained(
...   "distilbert-base-uncased-finetuned-sst-2-english",
...   export=True,
...   provider="ROCMExecutionProvider",
... )

然后可以使用常见的 🤗 Transformers API 进行推理和评估,例如 pipelines。在使用 Transformers pipeline 时,请注意应设置 device 参数以在 GPU 上执行预处理和后处理,如下例所示

>>> from optimum.pipelines import pipeline
>>> from transformers import AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

>>> pipe = pipeline(task="text-classification", model=ort_model, tokenizer=tokenizer, device="cuda:0")
>>> result = pipe("Both the music and visual were astounding, not to mention the actors performance.")
>>> print(result)
# printing: [{'label': 'POSITIVE', 'score': 0.9997727274894c714}]

此外,您可以传递会话选项 log_severity_level = 0(详细),以检查所有节点是否确实都放置在 ROCm 执行提供程序上。

>>> import onnxruntime

>>> session_options = onnxruntime.SessionOptions()
>>> session_options.log_severity_level = 0

>>> ort_model = ORTModelForSequenceClassification.from_pretrained(
...     "distilbert-base-uncased-finetuned-sst-2-english",
...     export=True,
...     provider="ROCMExecutionProvider",
...     session_options=session_options
... )

观察到的时间收益

即将推出!

< > GitHub 更新