使用Diffusers和PEFT为Flux实现快速LoRA推理
LoRA适配器为各种大小的模型提供了极大的定制化能力。在图像生成方面,它们可以赋予模型不同的风格、不同的角色等更多功能。有时,它们还可以用于减少推理延迟。因此,它们的重要性是至关重要的,尤其是在定制和微调模型时。
在这篇文章中,我们选择了Flux.1-Dev模型进行文本到图像生成,因为它广受欢迎且应用广泛。我们探讨了如何在使用LoRA时优化其推理速度(约2.3倍)。根据Hugging Face Hub平台上的报告,该模型已训练了超过3万个适配器。因此,它对社区的重要性是巨大的。
请注意,尽管我们演示了Flux的加速效果,但我们相信我们的方法足够通用,可以应用于其他模型。
如果您迫不及待想开始编码,请查看随附的代码库。
目录
优化LoRA推理的障碍
在提供LoRA服务时,通常会进行热插拔(即插拔不同的LoRA)。LoRA会改变基础模型的架构。此外,LoRA之间也可能不同——每个LoRA可能具有不同的秩,并针对不同的层进行适配。为了应对LoRA的这些动态特性,我们必须采取必要的措施来确保我们应用的优化是稳健的。
例如,我们可以在加载了特定LoRA的模型上应用torch.compile
,以提高推理延迟。但是,一旦我们将LoRA替换为另一个(可能具有不同配置的)LoRA,就会遇到重新编译的问题,导致推理速度下降。
还可以将LoRA参数融合到基础模型参数中,运行编译,然后在加载新参数时解除LoRA参数的融合。然而,这种方法在每次运行推理时,由于潜在的架构级更改,仍然会遇到重新编译的问题。
我们的优化方法考虑了上述情况,以尽可能地切合实际。以下是我们优化方法的核心组成部分:
- Flash Attention 3 (FA3)
torch.compile
- TorchAO的FP8量化
- 支持热插拔
请注意,在上述组件中,FP8量化是无损的,但通常能提供最强大的速度-内存权衡。尽管我们主要使用NVIDIA GPU测试了该方法,但它也应该适用于AMD GPU。
优化方法
在我们之前的博客文章(文章1和文章2)中,我们已经讨论了使用我们优化方法前三个组件的好处。逐一应用它们只需几行代码。
from diffusers import DiffusionPipeline, TorchAoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from utils.fa3_processor import FlashFluxAttnProcessor3_0
import torch
# quantize the Flux transformer with FP8
pipe = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16,
quantization_config=PipelineQuantizationConfig(
quant_mapping={"transformer": TorchAoConfig("float8dq_e4m3_row")}
)
).to("cuda")
# use Flash-attention 3
pipe.transformer.set_attn_processor(FlashFluxAttnProcessor3_0())
# use torch.compile()
pipe.transformer.compile(fullgraph=True, mode="max-autotune")
# perform inference
pipe_kwargs = {
"prompt": "A cat holding a sign that says hello world",
"height": 1024,
"width": 1024,
"guidance_scale": 3.5,
"num_inference_steps": 28,
"max_sequence_length": 512,
}
# first time will be slower, subsequent runs will be faster
image = pipe(**pipe_kwargs).images[0]
FA3处理器来自此处。
当我们将LoRA热插拔到已编译的扩散Transformer(`pipe.transformer`)中而不触发重新编译时,问题开始浮现。
通常,加载和卸载LoRA会需要重新编译,这会抵消编译带来的任何速度优势。幸运的是,有一种方法可以避免重新编译。通过传递`hotswap=True`,Diffusers将保持模型架构不变,只交换LoRA适配器本身的权重,这不需要重新编译。
pipe.enable_lora_hotswap(target_rank=max_rank)
pipe.load_lora_weights(<lora-adapter-name1>)
# compile *after* loading the first LoRA
pipe.transformer.compile(mode="max-autotune", fullgraph=True)
image = pipe(**pipe_kwargs).images[0]
# from this point on, load new LoRAs with `hotswap=True`
pipe.load_lora_weights(<lora-adapter-name2>, hotswap=True)
image = pipe(**pipe_kwargs).images[0]
(提醒一下,第一次调用`pipe`会很慢,因为`torch.compile`是即时编译器。然而,随后的调用应该会显著加快。)
这通常允许在不重新编译的情况下交换 LoRA,但存在一些限制:
- 我们需要提前提供所有 LoRA 适配器中的最大秩。因此,如果我们有一个秩为 16 的适配器,另一个秩为 32 的适配器,我们需要传递 `max_rank=32`。
- 热插拔的LoRA适配器只能针对第一个LoRA所针对的相同层或其子集。
- 目前尚不支持文本编码器目标化。
有关Diffusers中热插拔及其限制的更多信息,请访问文档中的热插拔部分。
当我们查看不使用编译进行热插拔时的推理延迟时,这种工作流程的好处变得显而易见。
选项 | 时间 (秒) ⬇️ | 加速 (对比基线) ⬆️ | 备注 |
---|---|---|---|
基准 | 7.8910 | – | 基线 |
已优化 | 3.5464 | 2.23倍 | 热插拔 + 编译,无重新编译卡顿(默认开启FP8) |
无FP8 | 4.3520 | 1.81倍 | 与“已优化”相同,但禁用FP8量化 |
无FA3 | 4.3020 | 1.84倍 | 禁用 FA3 (flash‑attention v3) |
基线 + 编译 | 5.0920 | 1.55倍 | 编译开启,但受间歇性重新编译停顿影响 |
无FA3_FP8 | 5.0850 | 1.55倍 | 禁用 FA3 和 FP8 |
无编译_FP8 | 7.5190 | 1.05倍 | 禁用 FP8 量化和编译 |
无编译 | 10.4340 | 0.76倍 | 禁用编译:最慢的设置 |
主要收获:
- “常规+编译”选项比常规选项提供了不错的加速,但它会引发重新编译问题,从而增加总执行时间。在我们的基准测试中,我们没有给出编译时间。
- 通过热插拔消除重新编译问题(也称为“优化”选项)时,我们实现了最高的加速。
- 在“优化”选项中,FP8量化已启用,这可能导致质量损失。即使不使用FP8,我们也能获得不错的加速(“无FP8”选项)。
- 为了演示目的,我们使用一个包含两个LoRA的池进行编译热插拔。有关完整代码,请参阅随附的代码库。
我们迄今讨论的优化方法假定能够访问像H100这样的强大GPU。然而,当我们受限于使用RTX 4090等消费级GPU时,我们能做些什么呢?让我们一探究竟。
在消费级GPU上优化LoRA推理
Flux.1-Dev(不带任何LoRA)使用Bfloat16数据类型运行,占用约33GB内存。根据LoRA模块的大小,如果不进行任何优化,内存占用还会进一步增加。许多消费级GPU,如RTX 4090,只有24GB内存。在本节的其余部分,我们将RTX 4090机器作为我们的测试平台。
首先,为了实现Flux.1-Dev的端到端执行,我们可以应用CPU卸载,将不需要执行当前计算的组件卸载到CPU,以释放更多加速器内存。这样做可以在RTX 4090上以约22GB的内存运行整个管道,耗时**35.403秒**。启用编译可以将延迟降低到**31.205秒**(1.12倍加速)。在代码方面,只需几行:
pipe = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
# Instead of full compilation, we apply regional compilation
# here to take advantage of `fullgraph=True` and also to reduce
# compilation time. More details can be found here:
# https://huggingface.co/docs/diffusers/main/en/optimization/fp16#regional-compilation
pipe.transformer.compile_repeated_blocks(fullgraph=True)
image = pipe(**pipe_kwargs).images[0]
请注意,我们在此处没有应用FP8量化,因为它不支持CPU卸载和编译(支持问题线程)。因此,仅将FP8量化应用于Flux Transformer不足以缓解内存耗尽问题。在这种情况下,我们决定将其移除。
因此,为了利用FP8量化方案,我们需要找到一种无需CPU卸载的方法。对于Flux.1-Dev,如果再对T5文本编码器进行量化,我们应该能够在24GB内存中加载和运行完整的管道。下面是T5文本编码器量化(来自bitsandbytes
的NF4量化)和未量化时的结果比较。
如上图所示,量化T5文本编码器并不会造成太大的质量损失。将量化后的T5文本编码器和FP8量化后的Flux Transformer与`torch.compile`结合使用,我们得到了相当不错的结果——从32.27秒降至**9.668秒**(大幅加速约3.3倍),且没有明显的质量下降。
即使不量化T5文本编码器,也可以用24GB的VRAM生成图像,但这会使我们的生成流程稍微复杂一些。
我们现在有了一种在RTX 4090上使用FP8量化运行整个Flux.1-Dev管道的方法。我们可以在相同的硬件上应用先前建立的优化LoRA推理方法。由于RTX 4090不支持FA3,我们将坚持以下优化方法,并新增T5量化:
- FP8量化
torch.compile
- 支持热插拔
- T5量化 (使用NF4)
在下表中,我们展示了应用上述组件不同组合的推理延迟数据。
选项 | 关键参数标志 | 时间 (秒) ⬇️ | 加速 (对比基线) ⬆️ |
---|---|---|---|
基准 | disable_fp8=False disable_compile=True quantize_t5=True offload=False |
23.6060 | – |
已优化 | disable_fp8=False disable_compile=False quantize_t5=True offload=False |
11.5715 | 2.04倍 |
简要说明:
- 编译比基线提供了巨大的2倍加速。
- 即使启用了卸载,其他选项也导致了OOM错误。
热插拔的技术细节
为了实现热插拔而不触发重新编译,必须克服两个障碍。首先,LoRA的缩放因子必须从浮点数转换为torch张量,这相对容易实现。其次,LoRA权重的形状需要填充到所需的最大形状。这样,可以替换权重中的数据而无需重新分配整个属性。这就是为什么上面讨论的`max_rank`参数至关重要。由于我们将值用零填充,结果保持不变,尽管计算速度会根据填充的大小而稍有减慢。
由于没有添加新的LoRA属性,这也要求第一个LoRA之后的每个LoRA只能针对第一个LoRA所针对的相同层或其子集。因此,请明智地选择加载顺序。如果LoRA针对不相交的层,则可以创建一个针对所有目标层并集的虚拟LoRA。
要查看此实现的详细信息,请访问PEFT中的`hotswap.py`文件。
结论
本文概述了一种用于Flux快速LoRA推理的优化方法,并展示了显著的加速效果。我们的方法结合了Flash Attention 3、`torch.compile`和FP8量化,同时确保了热插拔功能,避免了重新编译问题。在H100等高端GPU上,这种优化设置比基线提供了2.23倍的加速。
对于消费级GPU,特别是RTX 4090,我们通过引入T5文本编码器量化(NF4)和利用区域编译解决了内存限制。这种全面的方法实现了显著的2.04倍加速,即使在有限的VRAM下,也能使Flux上的LoRA推理变得可行且高效。关键在于,通过仔细管理编译和量化,LoRA的优势可以在不同的硬件配置上充分实现。
希望本文提供的秘诀能激发您优化基于LoRA的用例,从而受益于快速推理。
资源
以下是本文中引用的重要资源列表