Diffusers文档

PyTorch 2.0

Hugging Face's logo
加入Hugging Face社区

并获得增强文档体验的访问权限

开始使用

PyTorch 2.0

🤗 Diffusers 支持 PyTorch 2.0 的最新优化,包括

  1. 内存高效的注意力实现,缩放点积注意力,无需任何额外的依赖,例如 xFormers。
  2. torch.compile,即时编译器,为模型编译提供额外的性能提升。

这两种优化需要 PyTorch 2.0 或更高版本以及 🤗 Diffusers > 0.13.0。

pip install --upgrade torch diffusers

缩放点积注意力

torch.nn.functional.scaled_dot_product_attention (SDPA)是一种优化且内存高效的注意力(类似 xFormers),会根据模型输入和 GPU 类型自动启用其他优化。如果使用 PyTorch 2.0 和 🤗 Diffusers的最新版本,SDPA 默认启用,因此无需在代码中添加任何内容。

但是,如果想显式启用它,可以将 DiffusionPipeline 设置为使用 AttnProcessor2_0

  import torch
  from diffusers import DiffusionPipeline
+ from diffusers.models.attention_processor import AttnProcessor2_0

  pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+ pipe.unet.set_attn_processor(AttnProcessor2_0())

  prompt = "a photo of an astronaut riding a horse on mars"
  image = pipe(prompt).images[0]

SDPA 应该与 xFormers 的速度和内存效率相当;查看 性能基准 获取更多详细信息。

在某些情况下 - 例如使流水线更具确定性或转换为其他格式 - 使用原始注意力处理器(AttnProcessor)可能会有所帮助。要恢复到 AttnProcessor,请在流水线上调用 set_default_attn_processor() 函数

  import torch
  from diffusers import DiffusionPipeline

  pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+ pipe.unet.set_default_attn_processor()

  prompt = "a photo of an astronaut riding a horse on mars"
  image = pipe(prompt).images[0]

torch.compile

torch.compile 函数可以为您的 PyTorch 代码提供额外的加速。在 🤗 Diffusers 中,通常最好将 UNet 用 torch.compile 包装,因为它在管道中执行了大部分繁重的工作。

from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images[0]

根据 GPU 类型,torch.compile 可以提供在 SDPA 之上的 额外加速,提升速度可达 5-300 倍!如果您使用的是更先进的 GPU 架构,如 Ampere(A100、3090)、Ada(4090)和 Hopper(H100),torch.compile 可以从这些 GPU 中榨取更多性能。

编译需要一些时间,因此最适合于您仅准备管道一次,然后执行多次相同类型的推理操作的情况。例如,对于不同图像大小的编译管道调用将再次触发编译,这可能会很昂贵。

有关 torch.compile 的更多信息及其不同选项,请参阅 torch_compile 教程。

加速文本到图像扩散模型推理 教程中了解更多 PyTorch 2.0 帮助优化您模型的其他方式。

基准测试

我们对 PyTorch 2.0 高效的注意力实现和 torch.compile 在不同 GPU 和批次大小上对五个最常用的管道进行了全面的基准测试。该代码在 🤗 Diffusers v0.17.0.dev0 上进行基准测试以优化 torch.compile 的使用(有关更多详细信息,请参阅 此处)。

展开下面的下拉菜单,查找每个管道的基准测试代码

稳定扩散文本到图像

from diffusers import DiffusionPipeline
import torch

path = "runwayml/stable-diffusion-v1-5"

run_compile = True  # Set True / False

pipe = DiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16, use_safetensors=True)
pipe = pipe.to("cuda")
pipe.unet.to(memory_format=torch.channels_last)

if run_compile:
    print("Run torch compile")
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

prompt = "ghibli style, a fantasy landscape with castles"

for _ in range(3):
    images = pipe(prompt=prompt).images

稳定扩散图像到图像

from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.utils import load_image
import torch

url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

init_image = load_image(url)
init_image = init_image.resize((512, 512))

path = "runwayml/stable-diffusion-v1-5"

run_compile = True  # Set True / False

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(path, torch_dtype=torch.float16, use_safetensors=True)
pipe = pipe.to("cuda")
pipe.unet.to(memory_format=torch.channels_last)

if run_compile:
    print("Run torch compile")
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

prompt = "ghibli style, a fantasy landscape with castles"

for _ in range(3):
    image = pipe(prompt=prompt, image=init_image).images[0]

稳定扩散修复

from diffusers import StableDiffusionInpaintPipeline
from diffusers.utils import load_image
import torch

img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

init_image = load_image(img_url).resize((512, 512))
mask_image = load_image(mask_url).resize((512, 512))

path = "runwayml/stable-diffusion-inpainting"

run_compile = True  # Set True / False

pipe = StableDiffusionInpaintPipeline.from_pretrained(path, torch_dtype=torch.float16, use_safetensors=True)
pipe = pipe.to("cuda")
pipe.unet.to(memory_format=torch.channels_last)

if run_compile:
    print("Run torch compile")
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

prompt = "ghibli style, a fantasy landscape with castles"

for _ in range(3):
    image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]

控制网

from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.utils import load_image
import torch

url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

init_image = load_image(url)
init_image = init_image.resize((512, 512))

path = "runwayml/stable-diffusion-v1-5"

run_compile = True  # Set True / False
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16, use_safetensors=True)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    path, controlnet=controlnet, torch_dtype=torch.float16, use_safetensors=True
)

pipe = pipe.to("cuda")
pipe.unet.to(memory_format=torch.channels_last)
pipe.controlnet.to(memory_format=torch.channels_last)

if run_compile:
    print("Run torch compile")
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)

prompt = "ghibli style, a fantasy landscape with castles"

for _ in range(3):
    image = pipe(prompt=prompt, image=init_image).images[0]

DeepFloyd IF 文本到图像 + 上采样

from diffusers import DiffusionPipeline
import torch

run_compile = True  # Set True / False

pipe_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-M-v1.0", variant="fp16", text_encoder=None, torch_dtype=torch.float16, use_safetensors=True)
pipe_1.to("cuda")
pipe_2 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-II-M-v1.0", variant="fp16", text_encoder=None, torch_dtype=torch.float16, use_safetensors=True)
pipe_2.to("cuda")
pipe_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16, use_safetensors=True)
pipe_3.to("cuda")


pipe_1.unet.to(memory_format=torch.channels_last)
pipe_2.unet.to(memory_format=torch.channels_last)
pipe_3.unet.to(memory_format=torch.channels_last)

if run_compile:
    pipe_1.unet = torch.compile(pipe_1.unet, mode="reduce-overhead", fullgraph=True)
    pipe_2.unet = torch.compile(pipe_2.unet, mode="reduce-overhead", fullgraph=True)
    pipe_3.unet = torch.compile(pipe_3.unet, mode="reduce-overhead", fullgraph=True)

prompt = "the blue hulk"

prompt_embeds = torch.randn((1, 2, 4096), dtype=torch.float16)
neg_prompt_embeds = torch.randn((1, 2, 4096), dtype=torch.float16)

for _ in range(3):
    image_1 = pipe_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=neg_prompt_embeds, output_type="pt").images
    image_2 = pipe_2(image=image_1, prompt_embeds=prompt_embeds, negative_prompt_embeds=neg_prompt_embeds, output_type="pt").images
    image_3 = pipe_3(prompt=prompt, image=image_1, noise_level=100).images

以下图表显示了在 PyTorch 2.0 和启用 torch.compile 的情况下,五个 GPU 家族中的 StableDiffusionPipeline 的相对加速情况。后续图表的基准测试结果以 每秒迭代次数 计算。

t2i_speedup

为了更好地说明其他流程的加速情况,请参考以下 A100 在 PyTorch 2.0 和 torch.compile 下的图表。

a100_numbers

以下表格中,我们以 每秒迭代次数 为单位报告我们的发现。

A100(批量大小:1)

流程 torch 2.0 -
不编译
torch 夜间版本 -
不编译
torch 2.0 -
编译
torch 夜间版本 -
编译
SD - txt2img 21.66 23.13 44.03 49.74
SD - img2img 21.81 22.40 43.92 46.32
SD - inpaint 22.24 23.23 43.76 49.25
SD - controlnet 15.02 15.82 32.13 36.08
IF 20.21 /
13.84 /
24.00
20.12 /
13.70 /
24.03
97.34 /
27.23 /
111.66
SDXL - txt2img 8.64 9.9 - -

A100(批量大小:4)

流程 torch 2.0 -
不编译
torch 夜间版本 -
不编译
torch 2.0 -
编译
torch 夜间版本 -
编译
SD - txt2img 11.6 13.12 14.62 17.27
SD - img2img 11.47 13.06 14.66 17.25
SD - inpaint 11.67 13.31 14.88 17.48
SD - controlnet 8.28 9.38 10.51 12.41
IF 25.02 18.04 48.47
SDXL - txt2img 2.44 2.74 - -

A100 (批量大小:16)

流程 torch 2.0 -
不编译
torch 夜间版本 -
不编译
torch 2.0 -
编译
torch 夜间版本 -
编译
SD - txt2img 3.04 3.6 3.83 4.68
SD - img2img 2.98 3.58 3.83 4.67
SD - inpaint 3.04 3.66 3.9 4.76
SD - controlnet 2.15 2.58 2.74 3.35
IF 8.78 9.82 16.77
SDXL - txt2img 0.64 0.72 - -

V100 (批量大小:1)

流程 torch 2.0 -
不编译
torch 夜间版本 -
不编译
torch 2.0 -
编译
torch 夜间版本 -
编译
SD - txt2img 18.99 19.14 20.95 22.17
SD - img2img 18.56 19.18 20.95 22.11
SD - inpaint 19.14 19.06 21.08 22.20
SD - controlnet 13.48 13.93 15.18 15.88
IF 20.01 /
9.08 /
23.34
19.79 /
8.98 /
24.10
55.75 /
11.57 /
57.67

V100 (批量大小:4)

流程 torch 2.0 -
不编译
torch 夜间版本 -
不编译
torch 2.0 -
编译
torch 夜间版本 -
编译
SD - txt2img 5.96 5.89 6.83 6.86
SD - img2img 5.90 5.91 6.81 6.82
SD - inpaint 5.99 6.03 6.93 6.95
SD - controlnet 4.26 4.29 4.92 4.93
IF 15.41 14.76 22.95

V100 (批量大小:16)

流程 torch 2.0 -
不编译
torch 夜间版本 -
不编译
torch 2.0 -
编译
torch 夜间版本 -
编译
SD - txt2img 1.66 1.66 1.92 1.90
SD - img2img 1.65 1.65 1.91 1.89
SD - inpaint 1.69 1.69 1.95 1.93
SD - controlnet 1.19 1.19 预热后内存溢出 1.36
IF 5.43 5.29 7.06

T4(批量大小:1)

流程 torch 2.0 -
不编译
torch 夜间版本 -
不编译
torch 2.0 -
编译
torch 夜间版本 -
编译
SD - txt2img 6.9 6.95 7.3 7.56
SD - img2img 6.84 6.99 7.04 7.55
SD - inpaint 6.91 6.7 7.01 7.37
SD - controlnet 4.89 4.86 5.35 5.48
IF 17.42 /
2.47 /
18.52
16.96 /
2.45 /
18.69
24.63 /
2.47 /
23.39
SDXL - txt2img 1.15 1.16 - -

T4(批量大小:4)

流程 torch 2.0 -
不编译
torch 夜间版本 -
不编译
torch 2.0 -
编译
torch 夜间版本 -
编译
SD - txt2img 1.79 1.79 2.03 1.99
SD - img2img 1.77 1.77 2.05 2.04
SD - inpaint 1.81 1.82 2.09 2.09
SD - controlnet 1.34 1.27 1.47 1.46
IF 5.79 5.61 7.39
SDXL - txt2img 0.288 0.289 - -

T4(批量大小:16)

流程 torch 2.0 -
不编译
torch 夜间版本 -
不编译
torch 2.0 -
编译
torch 夜间版本 -
编译
SD - txt2img 2.34s 2.30s 在第2次迭代后内存耗尽 1.99s
SD - img2img 2.35s 2.31s 预热后内存溢出 2.00s
SD - inpaint 2.30s 2.26s 在第2次迭代后内存耗尽 1.95s
SD - controlnet 在第2次迭代后内存耗尽 在第2次迭代后内存耗尽 预热后内存溢出 预热后内存溢出
IF * 1.44 1.44 1.94
SDXL - txt2img 内存耗尽 内存耗尽 - -

RTX 3090(批量大小:1)

流程 torch 2.0 -
不编译
torch 夜间版本 -
不编译
torch 2.0 -
编译
torch 夜间版本 -
编译
SD - txt2img 22.56 22.84 23.84 25.69
SD - img2img 22.25 22.61 24.1 25.83
SD - inpaint 22.22 22.54 24.26 26.02
SD - controlnet 16.03 16.33 17.38 18.56
IF 27.08 /
9.07 /
31.23
26.75 /
8.92 /
31.47
68.08 /
11.16 /
65.29

RTX 3090 (批量大小:4)

流程 torch 2.0 -
不编译
torch 夜间版本 -
不编译
torch 2.0 -
编译
torch 夜间版本 -
编译
SD - txt2img 6.46 6.35 7.29 7.3
SD - img2img 6.33 6.27 7.31 7.26
SD - inpaint 6.47 6.4 7.44 7.39
SD - controlnet 4.59 4.54 5.27 5.26
IF 16.81 16.62 21.57

RTX 3090 (批量大小:16)

流程 torch 2.0 -
不编译
torch 夜间版本 -
不编译
torch 2.0 -
编译
torch 夜间版本 -
编译
SD - txt2img 1.7 1.69 1.93 1.91
SD - img2img 1.68 1.67 1.93 1.9
SD - inpaint 1.72 1.71 1.97 1.94
SD - controlnet 1.23 1.22 1.4 1.38
IF 5.01 5.00 6.33

RTX 4090 (批量大小:1)

流程 torch 2.0 -
不编译
torch 夜间版本 -
不编译
torch 2.0 -
编译
torch 夜间版本 -
编译
SD - txt2img 40.5 41.89 44.65 49.81
SD - img2img 40.39 41.95 44.46 49.8
SD - inpaint 40.51 41.88 44.58 49.72
SD - controlnet 29.27 30.29 32.26 36.03
IF 69.71 /
18.78 /
85.49
69.13 /
18.80 /
85.56
124.60 /
26.37 /
138.79
SDXL - txt2img 6.8 8.18 - -

RTX 4090 (批量大小:4)

流程 torch 2.0 -
不编译
torch 夜间版本 -
不编译
torch 2.0 -
编译
torch 夜间版本 -
编译
SD - txt2img 12.62 12.84 15.32 15.59
SD - img2img 12.61 12,.79 15.35 15.66
SD - inpaint 12.65 12.81 15.3 15.58
SD - controlnet 9.1 9.25 11.03 11.22
IF 31.88 31.14 43.92
SDXL - txt2img 2.19 2.35 - -

RTX 4090 (批量大小:16)

流程 torch 2.0 -
不编译
torch 夜间版本 -
不编译
torch 2.0 -
编译
torch 夜间版本 -
编译
SD - txt2img 3.17 3.2 3.84 3.85
SD - img2img 3.16 3.2 3.84 3.85
SD - inpaint 3.17 3.2 3.85 3.85
SD - controlnet 2.23 2.3 2.7 2.75
IF 9.26 9.2 13.31
SDXL - txt2img 0.52 0.53 - -

注释

  • 有关用于进行基准测试的环境的更多细节,请查看以下 PR
  • 在 DeepFloyd IF 流程中,对于批大小 > 1 的情况,我们仅在文本到图像生成的第一个 IF 流程中使用批大小 > 1,而不是用于放大。这意味着两个放大流程接收的批大小为 1。

感谢 PyTorch 团队的 Chillee(Horace He)在改进 Diffusers 中对 torch.compile() 的支持方面的支持。

< > GitHub 更新