使用 Quanto 和 Diffusers 实现内存高效的 Diffusion Transformer
在过去几个月中,我们见证了基于 Transformer 的扩散模型骨干网络在高分辨率文本到图像 (T2I) 生成领域的兴起。这些模型使用 Transformer 架构作为扩散过程的基础模块,取代了许多早期扩散模型中普遍使用的 UNet 架构。得益于 Transformer 的特性,这些骨干网络表现出良好的可扩展性,模型参数从 6 亿到 80 亿不等。
随着模型越来越大,内存需求也随之增加。这个问题更加严峻,因为一个扩散 pipeline 通常由几个组件组成:一个文本编码器、一个扩散模型骨干和一个图像解码器。此外,现代扩散 pipeline 使用多个文本编码器——例如,Stable Diffusion 3 就有三个。使用 FP16 精度运行 SD3 推理需要 18.765 GB 的 GPU 内存。
这些高内存需求使得在消费级 GPU 上使用这些模型变得困难,从而减缓了它们的普及和实验进程。在本文中,我们将展示如何利用 Diffusers 库中 Quanto 的量化工具来提高基于 Transformer 的扩散 pipeline 的内存效率。
目录
预备知识
有关 Quanto 的详细介绍,请参阅这篇文章。简而言之,Quanto 是一个基于 PyTorch 构建的量化工具包。它是 Hugging Face Optimum 的一部分,这是一套用于硬件优化的工具。
模型量化是 LLM 从业者中一种流行的工具,但在扩散模型中却不那么常见。Quanto 可以帮助弥补这一差距,以极少或无质量下降为代价,实现内存节省。
为了进行基准测试,我们使用 H100 GPU,环境如下:
除非另有说明,我们默认使用 FP16 进行计算。我们选择不对 VAE 进行量化,以防止出现数值不稳定问题。我们的基准测试代码可以在这里找到。
在撰写本文时,我们在 Diffusers 中有以下基于 Transformer 的用于文本到图像生成的扩散 pipeline:
我们还有 Latte,一个基于 Transformer 的文本到视频生成 pipeline。
为简洁起见,我们的研究仅限于以下三个:PixArt-Sigma、Stable Diffusion 3 和 Aura Flow。下表显示了它们扩散模型骨干的参数数量:
模型 | 模型权重 | # 参数 (十亿) |
---|---|---|
PixArt | https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 0.611 |
Stable Diffusion 3 | https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers | 2.028 |
Aura Flow | https://huggingface.co/fal/AuraFlow/ | 6.843 |
使用 Quanto 量化 DiffusionPipeline
使用 Quanto 量化模型非常直接。
from optimum.quanto import freeze, qfloat8, quantize
from diffusers import PixArtSigmaPipeline
import torch
pipeline = PixArtSigmaPipeline.from_pretrained(
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")
quantize(pipeline.transformer, weights=qfloat8)
freeze(pipeline.transformer)
我们对要量化的模块调用 `quantize()`,并指定要量化的内容。在上面的例子中,我们只量化参数,保持激活值不变。我们将参数量化到 FP8 数据类型。最后,我们调用 `freeze()` 来用量化后的参数替换原始参数。
然后我们可以正常调用这个 `pipeline`:
image = pipeline("ghibli style, a fantasy landscape with castles").images[0]
FP16 | FP8 的 Diffusion Transformer |
---|---|
![]() |
![]() |
我们注意到,使用 FP8 时,内存消耗减少,延迟略有增加,但质量几乎没有下降:
批量大小 | 量化 | 内存 (GB) | 延迟 (秒) |
---|---|---|---|
1 | 无 | 12.086 | 1.200 |
1 | FP8 | 11.547 | 1.540 |
4 | 无 | 12.087 | 4.482 |
4 | FP8 | 11.548 | 5.109 |
我们可以用同样的方法量化文本编码器:
quantize(pipeline.text_encoder, weights=qfloat8)
freeze(pipeline.text_encoder)
文本编码器也是一个 Transformer 模型,我们也可以对其进行量化。同时量化文本编码器和扩散模型骨干可以带来更大的内存改进:
批量大小 | 量化 | 量化文本编码器 (TE) | 内存 (GB) | 延迟 (秒) |
---|---|---|---|---|
1 | FP8 | 否 (False) | 11.547 | 1.540 |
1 | FP8 | 是 (True) | 5.363 | 1.601 |
4 | FP8 | 否 (False) | 11.548 | 5.109 |
4 | FP8 | 是 (True) | 5.364 | 5.141 |
量化文本编码器产生的结果与之前的情况非常相似:
观察的普适性
将文本编码器与扩散模型骨干一起量化通常适用于我们尝试过的模型。Stable Diffusion 3 是一个特例,因为它使用了三个不同的文本编码器。我们发现量化*第二个*文本编码器效果不佳,因此我们建议采用以下替代方案:
- 只量化第一个文本编码器 (
CLIPTextModelWithProjection
) 或 - 只量化第三个文本编码器 (
T5EncoderModel
) 或 - 量化第一个和第三个文本编码器
下表给出了各种文本编码器量化组合(所有情况下扩散 Transformer 都被量化)的预期内存节省情况:
批量大小 | 量化 | 量化 TE 1 | 量化 TE 2 | 量化 TE 3 | 内存 (GB) | 延迟 (秒) |
---|---|---|---|---|---|---|
1 | FP8 | 1 | 1 | 1 | 8.200 | 2.858 |
1 ✅ | FP8 | 0 | 0 | 1 | 8.294 | 2.781 |
1 | FP8 | 1 | 1 | 0 | 14.384 | 2.833 |
1 | FP8 | 0 | 1 | 0 | 14.475 | 2.818 |
1 ✅ | FP8 | 1 | 0 | 0 | 14.384 | 2.730 |
1 | FP8 | 0 | 1 | 1 | 8.325 | 2.875 |
1 ✅ | FP8 | 1 | 0 | 1 | 8.204 | 2.789 |
1 | 无 | - | - | - | 16.403 | 2.118 |
量化的文本编码器:1 | 量化的文本编码器:3 | 量化的文本编码器:1 和 3 |
---|---|---|
![]() |
![]() |
![]() |
其他发现
在 H100 上 bfloat16
通常更好
对于支持的 GPU 架构,如 H100 或 4090,使用 bfloat16
可能会更快。下表展示了在我们的 H100 参考硬件上测量的 PixArt 的一些数据:
批量大小 | 精度 | 量化 | 内存 (GB) | 延迟 (秒) | 量化文本编码器 (TE) |
---|---|---|---|---|---|
1 | FP16 | INT8 | 5.363 | 1.538 | 是 (True) |
1 | BF16 | INT8 | 5.364 | 1.454 | 是 (True) |
1 | FP16 | FP8 | 5.363 | 1.601 | 是 (True) |
1 | BF16 | FP8 | 5.363 | 1.495 | 是 (True) |
qint8
的前景
我们发现,使用 `qint8`(而不是 `qfloat8`)进行量化在推理延迟方面通常更好。当我们水平融合注意力 QKV 投影(在 Diffusers 中调用 `fuse_qkv_projections()`)时,这种效果会更加明显,从而加厚了 int8 内核的维度以加速计算。我们下面为 PixArt 提供了一些证据:
批量大小 | 量化 | 内存 (GB) | 延迟 (秒) | 量化文本编码器 (TE) | QKV 投影 |
---|---|---|---|---|---|
1 | INT8 | 5.363 | 1.538 | 是 (True) | 否 (False) |
1 | INT8 | 5.536 | 1.504 | 是 (True) | 是 (True) |
4 | INT8 | 5.365 | 5.129 | 是 (True) | 否 (False) |
4 | INT8 | 5.538 | 4.989 | 是 (True) | 是 (True) |
INT4 效果如何?
我们还额外试验了在使用 `bfloat16` 时采用 `qint4`。这仅适用于 H100 上的 `bfloat16`,因为其他配置尚不支持。使用 `qint4`,我们可以预期在内存消耗方面看到更多改进,但代价是推理延迟增加。延迟增加是预料之中的,因为没有原生硬件支持 int4 计算——权重使用 4 位传输,但计算仍然以 `bfloat16` 完成。下表显示了我们对 PixArt-Sigma 的结果:
批量大小 | 量化文本编码器 (TE) | 内存 (GB) | 延迟 (秒) |
---|---|---|---|
1 | 否 | 9.380 | 7.431 |
1 | 是 | 3.058 | 7.604 |
然而,请注意,由于 INT4 的激进离散化,最终结果可能会受到影响。这就是为什么,对于基于 Transformer 的模型,我们通常将最终的投影层排除在量化之外。在 Quanto 中,我们这样做:
quantize(pipeline.transformer, weights=qint4, exclude="proj_out")
freeze(pipeline.transformer)
"proj_out"
对应于 pipeline.transformer
中的最后一层。下表展示了各种设置下的结果:
量化 TE:否,层排除:无 | 量化 TE:否,层排除:"proj_out" | 量化 TE:是,层排除:无 | 量化 TE:是,层排除:"proj_out" |
---|---|---|---|
![]() |
![]() |
![]() |
![]() |
为了恢复损失的图像质量,一种常见的做法是进行量化感知训练,Quanto 也支持这种训练。该技术超出了本文的范围,如果您感兴趣,请随时与我们联系!
我们本次实验的所有结果都可以在这里找到。
附赠 - 在 Quanto 中保存和加载 Diffusers 模型
量化后的 Diffusers 模型可以被保存和加载:
from diffusers import PixArtTransformer2DModel
from optimum.quanto import QuantizedPixArtTransformer2DModel, qfloat8
model = PixArtTransformer2DModel.from_pretrained("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", subfolder="transformer")
qmodel = QuantizedPixArtTransformer2DModel.quantize(model, weights=qfloat8)
qmodel.save_pretrained("pixart-sigma-fp8")
生成的模型权重大小为 *587MB*,而不是原来的 2.44GB。然后我们可以加载它:
from optimum.quanto import QuantizedPixArtTransformer2DModel
import torch
transformer = QuantizedPixArtTransformer2DModel.from_pretrained("pixart-sigma-fp8")
transformer.to(device="cuda", dtype=torch.float16)
并在 `DiffusionPipeline` 中使用它:
from diffusers import DiffusionPipeline
import torch
pipe = DiffusionPipeline.from_pretrained(
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
transformer=None,
torch_dtype=torch.float16,
).to("cuda")
pipe.transformer = transformer
prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt).images[0]
未来,我们期望在初始化 pipeline 时可以直接传递 `transformer`,这样就可以这样工作了:
pipe = PixArtSigmaPipeline.from_pretrained(
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
- transformer=None,
+ transformer=transformer,
torch_dtype=torch.float16,
).to("cuda")
QuantizedPixArtTransformer2DModel
的实现可在此处参考。如果您希望在 Quanto 中支持更多 Diffusers 模型以便于保存和加载,请在此处提交一个 issue 并提及 `@sayakpaul`。
技巧
- 根据您的需求,您可能希望对不同的 pipeline 模块应用不同类型的量化。例如,您可以对文本编码器使用 FP8,但对扩散 Transformer 使用 INT8。得益于 Diffusers 和 Quanto 的灵活性,这可以无缝完成。
- 为了优化您的用例,您甚至可以将量化与 Diffusers 中的其他内存优化技术相结合,例如 `enable_model_cpu_offload()`。
结论
在本文中,我们展示了如何量化 Diffusers 中的 Transformer 模型并优化其内存消耗。当我们额外量化混合中的文本编码器时,量化的效果变得更加明显。我们希望您能将一些工作流程应用到您的项目中并从中受益 🤗。
感谢 Pedro Cuenca 对本文的详尽审阅。