Diffusers 文档
torchao
并获得增强的文档体验
开始使用
torchao
TorchAO 是一个用于 PyTorch 的架构优化库。它为推理和训练提供高性能的数据类型、优化技术和内核,并具有与原生 PyTorch 功能(如 torch.compile、FullyShardedDataParallel (FSDP) 等)的可组合性。
在开始之前,请确保您已安装 Pytorch 2.5+ 和 TorchAO。
pip install -U torch torchao
通过将 TorchAoConfig 传递给 from_pretrained() 来量化模型(您也可以加载预量化的模型)。这适用于任何模态的任何模型,只要它支持使用 Accelerate 加载并且包含 torch.nn.Linear
层。
以下示例仅将权重量化为 int8。
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
model_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=dtype,
)
pipe = FluxPipeline.from_pretrained(
model_id,
transformer=transformer,
torch_dtype=dtype,
)
pipe.to("cuda")
# Without quantization: ~31.447 GB
# With quantization: ~20.40 GB
print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
prompt = "A cat holding a sign that says hello world"
image = pipe(
prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
TorchAO 与 torch.compile 完全兼容,这使其与其他量化方法区分开来。这使得只需一行代码即可轻松加速推理。
# In the above code, add the following after initializing the transformer
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
有关 Flux 和 CogVideoX 的速度和内存基准测试,请参阅此处的表格。您还可以找到一些 torchao 基准测试 数字,适用于各种硬件。
torchao 还通过 autoquant 支持自动量化 API。自动量化通过比较每种技术在选定输入类型和形状上的性能,来确定适用于模型的最佳量化策略。目前,这可以直接在底层建模组件上使用。Diffusers 未来也将在配置选项中公开 autoquant。
TorchAoConfig
类接受三个参数
quant_type
:一个字符串值,提及以下量化类型之一。modules_to_not_convert
:一个模块完整/部分模块名称列表,对于这些模块不应执行量化。例如,为了不对 FluxTransformer2DModel 的第一个块执行任何量化,可以指定:modules_to_not_convert=["single_transformer_blocks.0"]
。kwargs
:一个关键字参数字典,用于传递给将根据quant_type
调用的底层量化方法。
支持的量化类型
torchao 支持仅权重量化以及 int8、float3-float8 和 uint1-uint7 的权重和动态激活量化。
仅权重量化将模型权重存储在特定的低比特数据类型中,但使用更高精度的数据类型(如 bfloat16
)执行计算。这降低了模型权重的内存需求,但保留了激活计算的内存峰值。
动态激活量化将模型权重存储在低比特 dtype 中,同时动态量化激活以节省更多内存。这降低了模型权重的内存需求,同时也降低了激活计算的内存开销。但是,这有时可能会以质量为代价,因此建议彻底测试不同的模型。
支持的量化方法如下
类别 | 完整函数名称 | 简写 |
---|---|---|
整数量化 | int4_weight_only , int8_dynamic_activation_int4_weight , int8_weight_only , int8_dynamic_activation_int8_weight | int4wo , int4dq , int8wo , int8dq |
浮点 8 比特量化 | float8_weight_only , float8_dynamic_activation_float8_weight , float8_static_activation_float8_weight | float8wo , float8wo_e5m2 , float8wo_e4m3 , float8dq , float8dq_e4m3 , float8_e4m3_tensor , float8_e4m3_row |
浮点 X 比特量化 | fpx_weight_only | fpX_eAwB ,其中 X 是比特数(1-7),A 是指数比特,B 是尾数比特。约束条件:X == A + B + 1 |
无符号整数量化 | uintx_weight_only | uint1wo , uint2wo , uint3wo , uint4wo , uint5wo , uint6wo , uint7wo |
某些量化方法是别名(例如,int8wo
是常用的 int8_weight_only
的简写)。这允许按原样使用 torchao 文档中描述的量化方法,同时也方便记住它们的简写符号。
有关可用量化方法和可用配置选项的详尽列表的更好理解,请参阅官方 torchao 文档。
序列化和反序列化量化模型
要在给定的 dtype 中序列化量化模型,请首先使用所需的量化 dtype 加载模型,然后使用 save_pretrained() 方法保存它。
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("/path/to/flux_int8wo", safe_serialization=False)
要加载序列化的量化模型,请使用 from_pretrained() 方法。
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
image.save("output.png")
某些量化方法(如 uint4wo
)无法直接加载,并且在尝试加载模型时可能会导致 UnpicklingError
,但在保存模型时可以按预期工作。为了解决这个问题,可以将状态字典手动加载到模型中。但请注意,这需要在 torch.load
中使用 weights_only=False
,因此仅应在权重来自可信来源时运行。
import torch
from accelerate import init_empty_weights
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
# Serialize the model
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=TorchAoConfig("uint4wo"),
torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
# ...
# Load the model
state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
with init_empty_weights():
transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
transformer.load_state_dict(state_dict, strict=True, assign=True)