Diffusers 文档

ParaAttention

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

ParaAttention

大型图像和视频生成模型,例如 FLUX.1-devHunyuanVideo,由于其庞大的规模,对实时应用和部署来说可能是一个推理挑战。

ParaAttention 是一个实现**上下文并行**和**首块缓存**的库,可以与其他技术(torch.compile,fp8 动态量化)结合使用,以加速推理。

本指南将向您展示如何在 NVIDIA L20 GPU 上将 ParaAttention 应用于 FLUX.1-dev 和 HunyuanVideo。除了 HunyuanVideo 为避免内存不足错误而进行的优化外,我们的基准测试没有应用任何优化。

我们的基准测试显示,FLUX.1-dev 能够在 28 步内生成 1024x1024 分辨率的图像,耗时 26.36 秒;HunyuanVideo 能够在 30 步内生成 129 帧 720p 分辨率的视频,耗时 3675.71 秒。

要通过上下文并行实现更快的推理,请尝试使用支持 NVLink 的 NVIDIA A100 或 H100 GPU(如果可用),尤其是在 GPU 数量较多的情况下。

首块缓存

缓存模型中 Transformer 块的输出并在接下来的推理步骤中重复使用它们可以降低计算成本并加快推理速度。

然而,很难决定何时重用缓存以确保生成图像或视频的质量。ParaAttention 直接使用**第一个 Transformer 块输出的残差差异**来近似模型输出之间的差异。当差异足够小时,前一个推理步骤的残差差异将被重用。换句话说,去噪步骤被跳过。

这在 FLUX.1-dev 和 HunyuanVideo 推理中实现了 2 倍的加速,同时保持了非常好的质量。

Cache in Diffusion Transformer
AdaCache 的工作原理,First Block Cache 是它的一个变体
FLUX-1.dev
HunyuanVideo

要在 FLUX.1-dev 上应用首块缓存,请按如下所示调用 `apply_cache_on_pipe`。0.08 是 FLUX 模型默认的残差差异值。

import time
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

apply_cache_on_pipe(pipe, residual_diff_threshold=0.08)

# Enable memory savings
# pipe.enable_model_cpu_offload()
# pipe.enable_sequential_cpu_offload()

begin = time.time()
image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=28,
).images[0]
end = time.time()
print(f"Time: {end - begin:.2f}s")

print("Saving image to flux.png")
image.save("flux.png")
优化 原始 FBCache rdt=0.06 FBCache rdt=0.08 FBCache rdt=0.10 FBCache rdt=0.12
预览 Original FBCache rdt=0.06 FBCache rdt=0.08 FBCache rdt=0.10 FBCache rdt=0.12
墙钟时间 (秒) 26.36 21.83 17.01 16.00 13.78

与基准线相比,首块缓存将推理速度降低到 17.01 秒,即快 1.55 倍,同时几乎没有质量损失。

fp8 量化

带有动态量化的 fp8 进一步加快了推理速度并减少了内存使用。为了使用 8 位 NVIDIA Tensor Cores,必须量化激活和权重。

使用 `float8_weight_only` 和 `float8_dynamic_activation_float8_weight` 量化文本编码器和 Transformer 模型。

默认的量化方法是每张量量化,但如果您的 GPU 支持行级量化,您也可以尝试它以获得更好的精度。

使用以下命令安装 torchao

pip3 install -U torch torchao

torch.compile 使用 `mode="max-autotune-no-cudagraphs"` 或 `mode="max-autotune"` 选择最佳内核以获得性能。如果模型是第一次调用,编译可能需要很长时间,但一旦模型编译完成,这是值得的。

此示例仅量化 Transformer 模型,但您也可以量化文本编码器以进一步减少内存使用。

动态量化会显著改变模型输出的分布,因此您需要将 `residual_diff_threshold` 更改为更大的值才能使其生效。

FLUX-1.dev
HunyuanVideo
import time
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

apply_cache_on_pipe(
    pipe,
    residual_diff_threshold=0.12,  # Use a larger value to make the cache take effect
)

from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only

quantize_(pipe.text_encoder, float8_weight_only())
quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())
pipe.transformer = torch.compile(
   pipe.transformer, mode="max-autotune-no-cudagraphs",
)

# Enable memory savings
# pipe.enable_model_cpu_offload()
# pipe.enable_sequential_cpu_offload()

for i in range(2):
    begin = time.time()
    image = pipe(
        "A cat holding a sign that says hello world",
        num_inference_steps=28,
    ).images[0]
    end = time.time()
    if i == 0:
        print(f"Warm up time: {end - begin:.2f}s")
    else:
        print(f"Time: {end - begin:.2f}s")

print("Saving image to flux.png")
image.save("flux.png")

fp8 动态量化和 torch.compile 将推理速度降低到 7.56 秒,比基线快 3.48 倍。

上下文并行

上下文并行可并行化推理并随多个 GPU 扩展。ParaAttention 的组合设计允许您将上下文并行与首块缓存和动态量化相结合。

请参阅 ParaAttention 仓库,了解如何使用多个 GPU 扩展推理的详细说明和示例。

如果推理过程需要持久且可服务,建议使用 torch.multiprocessing 编写自己的推理处理器。这可以消除启动进程以及加载和重新编译模型的开销。

FLUX-1.dev
HunyuanVideo

以下代码示例结合了首块缓存、fp8 动态量化、torch.compile 和上下文并行,以实现最快的推理速度。

import time
import torch
import torch.distributed as dist
from diffusers import FluxPipeline

dist.init_process_group()

torch.cuda.set_device(dist.get_rank())

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
from para_attn.parallel_vae.diffusers_adapters import parallelize_vae

mesh = init_context_parallel_mesh(
    pipe.device.type,
    max_ring_dim_size=2,
)
parallelize_pipe(
    pipe,
    mesh=mesh,
)
parallelize_vae(pipe.vae, mesh=mesh._flatten())

from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

apply_cache_on_pipe(
    pipe,
    residual_diff_threshold=0.12,  # Use a larger value to make the cache take effect
)

from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only

quantize_(pipe.text_encoder, float8_weight_only())
quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())
torch._inductor.config.reorder_for_compute_comm_overlap = True
pipe.transformer = torch.compile(
   pipe.transformer, mode="max-autotune-no-cudagraphs",
)

# Enable memory savings
# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank())
# pipe.enable_sequential_cpu_offload(gpu_id=dist.get_rank())

for i in range(2):
    begin = time.time()
    image = pipe(
        "A cat holding a sign that says hello world",
        num_inference_steps=28,
        output_type="pil" if dist.get_rank() == 0 else "pt",
    ).images[0]
    end = time.time()
    if dist.get_rank() == 0:
        if i == 0:
            print(f"Warm up time: {end - begin:.2f}s")
        else:
            print(f"Time: {end - begin:.2f}s")

if dist.get_rank() == 0:
    print("Saving image to flux.png")
    image.save("flux.png")

dist.destroy_process_group()

保存到 `run_flux.py` 并使用 torchrun 启动它。

# Use --nproc_per_node to specify the number of GPUs
torchrun --nproc_per_node=2 run_flux.py

使用 2 个 NVIDIA L20 GPU,推理速度比基线降低到 8.20 秒,即快 3.21 倍。在 4 个 L20 GPU 上,推理速度为 3.90 秒,即快 6.75 倍。

基准测试

FLUX-1.dev
HunyuanVideo
GPU 类型 GPU 数量 优化 墙钟时间 (秒) 加速比
NVIDIA L20 1 基线 26.36 1.00倍
NVIDIA L20 1 FBCache (rdt=0.08) 17.01 1.55倍
NVIDIA L20 1 FP8 DQ 13.40 1.96倍
NVIDIA L20 1 FBCache (rdt=0.12) + FP8 DQ 7.56 3.48倍
NVIDIA L20 2 FBCache (rdt=0.12) + FP8 DQ + CP 4.92 5.35倍
NVIDIA L20 4 FBCache (rdt=0.12) + FP8 DQ + CP 3.90 6.75倍
< > 在 GitHub 上更新