使用 torch.compile() 优化推理
本指南旨在为 torch.compile()
针对 🤗 Transformers 中的计算机视觉模型 带来的推理加速提供一个基准测试。
torch.compile 的优势
根据模型和 GPU 的不同,torch.compile()
在推理过程中可以带来高达 30% 的速度提升。要使用 torch.compile()
,只需安装任何高于 2.0 版本的 torch
。
编译模型需要时间,因此如果你只编译一次模型,而不是每次推理都编译,则会很有用。要编译你选择的任何计算机视觉模型,请在模型上调用 torch.compile()
,如下所示
from transformers import AutoModelForImageClassification
model = AutoModelForImageClassification.from_pretrained(MODEL_ID).to("cuda")
+ model = torch.compile(model)
compile()
有多种编译模式,它们本质上在编译时间和推理开销方面有所不同。max-autotune
比 reduce-overhead
花费的时间更长,但推理速度更快。默认模式的编译速度最快,但与 reduce-overhead
相比,推理时间的效率不高。在本指南中,我们使用了默认模式。你可以了解更多有关它的信息 此处。
我们使用不同的计算机视觉模型、任务、硬件类型和批次大小在 torch
2.0.1 版本上对 torch.compile
进行了基准测试。
基准测试代码
以下你可以找到每个任务的基准测试代码。我们在推理之前预热 GPU,并使用相同图像,每次进行 300 次推理,取平均时间。
使用 ViT 进行图像分类
import torch
from PIL import Image
import requests
import numpy as np
from transformers import AutoImageProcessor, AutoModelForImageClassification
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224").to("cuda")
model = torch.compile(model)
processed_input = processor(image, return_tensors='pt').to(device="cuda")
with torch.no_grad():
_ = model(**processed_input)
使用 DETR 进行目标检测
from transformers import AutoImageProcessor, AutoModelForObjectDetection
processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50").to("cuda")
model = torch.compile(model)
texts = ["a photo of a cat", "a photo of a dog"]
inputs = processor(text=texts, images=image, return_tensors="pt").to("cuda")
with torch.no_grad():
_ = model(**inputs)
使用 Segformer 进行图像分割
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512").to("cuda")
model = torch.compile(model)
seg_inputs = processor(images=image, return_tensors="pt").to("cuda")
with torch.no_grad():
_ = model(**seg_inputs)
以下你可以找到我们基准测试的模型列表。
图像分类
- google/vit-base-patch16-224
- microsoft/beit-base-patch16-224-pt22k-ft22k
- facebook/convnext-large-224
- microsoft/resnet-50
图像分割
- nvidia/segformer-b0-finetuned-ade-512-512
- facebook/mask2former-swin-tiny-coco-panoptic
- facebook/maskformer-swin-base-ade
- google/deeplabv3_mobilenet_v2_1.0_513
目标检测
以下你可以找到使用和不使用 torch.compile()
的推理持续时间可视化以及不同硬件和批次大小下每个模型的百分比改进。
以下你可以找到使用和不使用 compile()
的每个模型的推理持续时间(以毫秒为单位)。请注意,OwlViT 在较大的批次大小下会导致 OOM。
A100(批次大小:1)
任务/模型 | torch 2.0 - 无编译 | torch 2.0 - 编译 |
---|---|---|
图像分类/ViT | 9.325 | 7.584 |
图像分割/Segformer | 11.759 | 10.500 |
目标检测/OwlViT | 24.978 | 18.420 |
图像分类/BeiT | 11.282 | 8.448 |
目标检测/DETR | 34.619 | 19.040 |
图像分类/ConvNeXT | 10.410 | 10.208 |
图像分类/ResNet | 6.531 | 4.124 |
图像分割/Mask2former | 60.188 | 49.117 |
图像分割/Maskformer | 75.764 | 59.487 |
图像分割/MobileNet | 8.583 | 3.974 |
目标检测/Resnet-101 | 36.276 | 18.197 |
目标检测/Conditional-DETR | 31.219 | 17.993 |
A100(批次大小:4)
任务/模型 | torch 2.0 - 无编译 | torch 2.0 - 编译 |
---|---|---|
图像分类/ViT | 14.832 | 14.499 |
图像分割/Segformer | 18.838 | 16.476 |
图像分类/BeiT | 13.205 | 13.048 |
目标检测/DETR | 48.657 | 32.418 |
图像分类/ConvNeXT | 22.940 | 21.631 |
图像分类/ResNet | 6.657 | 4.268 |
图像分割/Mask2former | 74.277 | 61.781 |
图像分割/Maskformer | 180.700 | 159.116 |
图像分割/MobileNet | 14.174 | 8.515 |
目标检测/Resnet-101 | 68.101 | 44.998 |
目标检测/Conditional-DETR | 56.470 | 35.552 |
A100(批次大小:16)
任务/模型 | torch 2.0 - 无编译 | torch 2.0 - 编译 |
---|---|---|
图像分类/ViT | 40.944 | 40.010 |
图像分割/Segformer | 37.005 | 31.144 |
图像分类/BeiT | 41.854 | 41.048 |
目标检测/DETR | 164.382 | 161.902 |
图像分类/ConvNeXT | 82.258 | 75.561 |
图像分类/ResNet | 7.018 | 5.024 |
图像分割/Mask2former | 178.945 | 154.814 |
图像分割/Maskformer | 638.570 | 579.826 |
图像分割/MobileNet | 51.693 | 30.310 |
目标检测/Resnet-101 | 232.887 | 155.021 |
目标检测/Conditional-DETR | 180.491 | 124.032 |