Diffusers 文档

减少内存使用

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

减少内存使用

使用扩散模型的障碍之一是需要大量的内存。为了克服这一挑战,您可以使用几种减少内存的技术,即使在免费层级或消费级 GPU 上也能运行一些最大的模型。其中一些技术甚至可以结合使用,以进一步减少内存使用。

在许多情况下,针对内存或速度进行优化也会提高另一方面的性能,因此您应该尽可能尝试同时进行优化。本指南侧重于最大限度地减少内存使用,但您也可以了解更多关于如何加速推理的信息。

以下结果是通过从提示词“a photo of an astronaut riding a horse on mars”生成单张 512x512 图像,并在 Nvidia Titan RTX 上使用 50 个 DDIM 步骤获得的,展示了您由于内存消耗减少而可以预期的速度提升。

延迟 加速
原始 9.50秒 x1
fp16 3.61秒 x2.63
channels last 3.30秒 x2.88
traced UNet 3.21秒 x2.96
memory-efficient attention 2.63秒 x3.61

Sliced VAE

Sliced VAE 允许解码 VRAM 有限的大批量图像或 32 张或更多图像的批次,方法是一次解码一个图像的潜在批次。如果您安装了 xFormers,您可能需要将此与enable_xformers_memory_efficient_attention()结合使用,以进一步减少内存使用。

要使用 sliced VAE,请在推理前在您的 pipeline 上调用enable_vae_slicing()

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_vae_slicing()
#pipe.enable_xformers_memory_efficient_attention()
images = pipe([prompt] * 32).images

您可能会看到在多图像批次上 VAE 解码的性能略有提升,并且在单图像批次上应该没有性能影响。

Tiled VAE

Tiled VAE 处理还允许在 VRAM 有限的情况下处理大型图像(例如,在 8GB VRAM 上生成 4k 图像),方法是将图像分割成重叠的瓦片,解码瓦片,然后将输出混合在一起以合成最终图像。如果您安装了 xFormers,您还应该将 tiled VAE 与enable_xformers_memory_efficient_attention()一起使用,以进一步减少内存使用。

要使用 tiled VAE 处理,请在推理前在您的 pipeline 上调用enable_vae_tiling()

import torch
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
prompt = "a beautiful landscape photograph"
pipe.enable_vae_tiling()
#pipe.enable_xformers_memory_efficient_attention()

image = pipe([prompt], width=3840, height=2224, num_inference_steps=20).images[0]

输出图像具有一些瓦片到瓦片的色调变化,因为瓦片是单独解码的,但您不应看到瓦片之间有任何明显的尖锐接缝。对于 512x512 或更小的图像,平铺功能将被关闭。

CPU 卸载

将权重卸载到 CPU,仅在执行前向传递时将其加载到 GPU 上,也可以节省内存。通常,此技术可以将内存消耗减少到 3GB 以下。

要执行 CPU 卸载,请调用enable_sequential_cpu_offload()

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
)

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_sequential_cpu_offload()
image = pipe(prompt).images[0]

CPU 卸载作用于子模块而不是整个模型。这是最大限度减少内存消耗的最佳方法,但由于扩散过程的迭代性质,推理速度会慢得多。pipeline 的 UNet 组件会运行多次(多达 num_inference_steps 次);每次运行时,不同的 UNet 子模块会根据需要依次加载和卸载,从而导致大量的内存传输。

如果您想优化速度,请考虑使用模型卸载,因为它速度更快。权衡之处在于您的内存节省不会那么大。

当使用enable_sequential_cpu_offload()时,不要预先将 pipeline 移动到 CUDA,否则内存消耗的增益将仅是最小的(有关更多信息,请参阅此问题)。

enable_sequential_cpu_offload()是一个有状态的操作,会在模型上安装 hooks。

模型卸载

模型卸载需要 🤗 Accelerate 版本 0.17.0 或更高版本。

顺序 CPU 卸载可以保留大量内存,但会降低推理速度,因为子模块会根据需要移动到 GPU,并在新模块运行时立即返回到 CPU。

全模型卸载是一种替代方案,它将整个模型移动到 GPU,而不是处理每个模型的组成子模块。与将 pipeline 移动到 cuda 相比,对推理时间的影响可以忽略不计,并且仍然可以节省一些内存。

在模型卸载期间,pipeline 的主要组件(通常是文本编码器、UNet 和 VAE)中只有一个放置在 GPU 上,而其他组件则在 CPU 上等待。像 UNet 这样运行多次的组件会一直留在 GPU 上,直到不再需要它们为止。

通过在 pipeline 上调用enable_model_cpu_offload()来启用模型卸载

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
)

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_model_cpu_offload()
image = pipe(prompt).images[0]

为了在模型被调用后正确卸载模型,需要运行整个 pipeline,并且模型按照 pipeline 的预期顺序被调用。如果在安装 hooks 后在 pipeline 上下文之外重用模型,请谨慎操作。有关更多信息,请参阅Removing Hooks

enable_model_cpu_offload()是一个有状态的操作,会在模型上安装 hooks,并在 pipeline 上安装状态。

Channels-last 内存格式

channels-last 内存格式是 NCHW 张量在内存中排序以保留维度顺序的另一种方式。channels-last 张量的排序方式使通道成为最密集的维度(逐像素存储图像)。由于并非所有运算符当前都支持 channels-last 格式,因此可能会导致性能下降,但您仍然应该尝试看看它是否适用于您的模型。

例如,要将 pipeline 的 UNet 设置为使用 channels-last 格式

print(pipe.unet.conv_out.state_dict()["weight"].stride())  # (2880, 9, 3, 1)
pipe.unet.to(memory_format=torch.channels_last)  # in-place operation
print(
    pipe.unet.conv_out.state_dict()["weight"].stride()
)  # (2880, 1, 960, 320) having a stride of 1 for the 2nd dimension proves that it works

Tracing

Tracing 通过模型运行示例输入张量,并捕获在该输入通过模型的各个层时对其执行的操作。返回的可执行文件或 ScriptFunction 通过即时编译进行优化。

要 tracing UNet

import time
import torch
from diffusers import StableDiffusionPipeline
import functools

# torch disable grad
torch.set_grad_enabled(False)

# set variables
n_experiments = 2
unet_runs_per_experiment = 50


# load inputs
def generate_inputs():
    sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
    timestep = torch.rand(1, device="cuda", dtype=torch.float16) * 999
    encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
    return sample, timestep, encoder_hidden_states


pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")
unet = pipe.unet
unet.eval()
unet.to(memory_format=torch.channels_last)  # use channels_last memory format
unet.forward = functools.partial(unet.forward, return_dict=False)  # set return_dict=False as default

# warmup
for _ in range(3):
    with torch.inference_mode():
        inputs = generate_inputs()
        orig_output = unet(*inputs)

# trace
print("tracing..")
unet_traced = torch.jit.trace(unet, inputs)
unet_traced.eval()
print("done tracing")


# warmup and optimize graph
for _ in range(5):
    with torch.inference_mode():
        inputs = generate_inputs()
        orig_output = unet_traced(*inputs)


# benchmarking
with torch.inference_mode():
    for _ in range(n_experiments):
        torch.cuda.synchronize()
        start_time = time.time()
        for _ in range(unet_runs_per_experiment):
            orig_output = unet_traced(*inputs)
        torch.cuda.synchronize()
        print(f"unet traced inference took {time.time() - start_time:.2f} seconds")
    for _ in range(n_experiments):
        torch.cuda.synchronize()
        start_time = time.time()
        for _ in range(unet_runs_per_experiment):
            orig_output = unet(*inputs)
        torch.cuda.synchronize()
        print(f"unet inference took {time.time() - start_time:.2f} seconds")

# save the model
unet_traced.save("unet_traced.pt")

将 pipeline 的 unet 属性替换为 traced 模型

from diffusers import StableDiffusionPipeline
import torch
from dataclasses import dataclass


@dataclass
class UNet2DConditionOutput:
    sample: torch.Tensor


pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")

# use jitted unet
unet_traced = torch.jit.load("unet_traced.pt")


# del pipe.unet
class TracedUNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.in_channels = pipe.unet.config.in_channels
        self.device = pipe.unet.device

    def forward(self, latent_model_input, t, encoder_hidden_states):
        sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
        return UNet2DConditionOutput(sample=sample)


pipe.unet = TracedUNet()

with torch.inference_mode():
    image = pipe([prompt] * 1, num_inference_steps=50).images[0]

Memory-efficient attention

最近关于优化注意力模块带宽的工作已经产生了巨大的加速效果,并减少了 GPU 内存使用。最新的内存高效注意力机制是 Flash Attention (你可以查看 HazyResearch/flash-attention 上的原始代码)。

如果您已安装 PyTorch >= 2.0,则在启用 xformers 时,不应期望在推理速度上有所提升。

要使用 Flash Attention,请安装以下组件

然后对 pipeline 调用 enable_xformers_memory_efficient_attention()

from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")

pipe.enable_xformers_memory_efficient_attention()

with torch.inference_mode():
    sample = pipe("a small cat")

# optional: You can disable it via
# pipe.disable_xformers_memory_efficient_attention()

使用 xformers 时的迭代速度应与 此处 描述的 PyTorch 2.0 的迭代速度相匹配。

< > 在 GitHub 上更新