(LoRA) 在消费级硬件上微调 FLUX.1-dev
在上一篇文章《探索 Diffusers 中的量化后端》中,我们深入探讨了各种量化技术如何缩小像 FLUX.1-dev 这样的扩散模型,使其在不大幅影响性能的情况下,大大提高了进行**推理**的可访问性。我们看到了 bitsandbytes
、torchao
等如何减少生成图像的内存占用。
执行推理很酷,但要真正让这些模型成为我们自己的,我们还需要能够微调它们。因此,在这篇文章中,我们将探讨**高效**地**微调**这些模型,单 GPU 峰值内存使用量低于约 10 GB 显存。本文将指导您使用 diffusers
库通过 QLoRA 微调 FLUX.1-dev。我们将展示 NVIDIA RTX 4090 的结果。我们还将强调如何通过 torchao
进行 FP8 训练,以在兼容硬件上进一步优化速度。
目录
- 数据集
- FLUX 架构
- 使用 diffusers 对 FLUX.1-dev 进行 QLoRA 微调
- 使用
torchao
进行 FP8 微调 - 使用训练好的 LoRA 适配器进行推理
- 在 Google Colab 上运行
- 结论
数据集
我们旨在微调 black-forest-labs/FLUX.1-dev
,使其采用阿尔丰斯·穆夏的艺术风格,使用一个小型数据集。
FLUX 架构
该模型由三个主要组件组成
- 文本编码器(CLIP 和 T5)
- Transformer(主模型 - Flux Transformer)
- 变分自编码器(VAE)
在我们的 QLoRA 方法中,我们**只**关注**微调 Transformer 组件**。文本编码器和 VAE 在整个训练过程中保持冻结状态。
使用 Diffusers 对 FLUX.1-dev 进行 QLoRA 微调
我们使用了一个 diffusers
训练脚本(在此处稍作修改,旨在用于 FLUX 模型的 DreamBooth 风格 LoRA 微调。此外,此处提供了一个简化版本,用于复现本博文中的结果(并在 Google Colab 中使用)。让我们检查 QLoRA 和内存效率的关键部分
关键优化技术
LoRA(低秩适应)深入探讨: LoRA 通过使用低秩矩阵跟踪权重更新,使模型训练更高效。LoRA 不会更新完整的权重矩阵 ,而是学习两个较小的矩阵 和 。模型权重的更新为 ,其中 和 。数字 (称为*秩*)远小于原始维度,这意味着需要更新的参数更少。最后, 是 LoRA 激活的缩放因子。它影响 LoRA 对更新的影响程度,通常设置为与 相同的值或其倍数。它有助于平衡预训练模型和 LoRA 适配器的影响。有关该概念的总体介绍,请查看我们之前的博文:《使用 LoRA 进行高效 Stable Diffusion 微调》。
QLoRA:效率利器: QLoRA 通过首先以量化格式(通常通过 bitsandbytes
以 4 位格式)加载预训练的基础模型来增强 LoRA,从而大幅削减基础模型的内存占用。然后,它在此量化基础之上训练 LoRA 适配器(通常为 FP16/BF16)。这显著降低了保存基础模型所需的显存。
例如,在HiDream 的 DreamBooth 训练脚本中,使用 bitsandbytes 进行 4 位量化将 LoRA 微调的峰值内存使用量从约 60GB 降低到约 37GB,而质量退化可忽略不计。我们在此处应用相同的原理来在消费级硬件上微调 FLUX.1。
8 位优化器 (AdamW): 标准 AdamW 优化器以 32 位(FP32)维护每个参数的一阶和二阶矩估计,这会消耗大量内存。8 位 AdamW 使用块级量化以 8 位精度存储优化器状态,同时保持训练稳定性。与标准 FP32 AdamW 相比,此技术可将优化器内存使用量减少约 75%。在脚本中启用它非常简单
# Check for the --use_8bit_adam flag
if args.use_8bit_adam:
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(
params_to_optimize,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
梯度检查点: 在前向传播过程中,通常会存储中间激活以用于后向传播梯度计算。梯度检查点通过仅存储某些*检查点激活*并在反向传播期间重新计算其他激活来权衡计算与内存。
if args.gradient_checkpointing:
transformer.enable_gradient_checkpointing()
缓存潜变量: 这种优化技术在训练开始前,通过 VAE 编码器预处理所有训练图像。它将生成的潜变量表示存储在内存中。在训练期间,直接使用缓存的潜变量,而不是即时编码图像。这种方法提供两个主要好处
- 消除了训练过程中冗余的 VAE 编码计算,加快了每个训练步骤的速度
- 允许 VAE 在缓存后完全从 GPU 内存中移除。缺点是存储所有缓存的潜变量会增加 RAM 使用量,但这对于小型数据集通常是可控的。
# Cache latents before training if the flag is set
if args.cache_latents:
latents_cache = []
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=weight_dtype
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
# VAE is no longer needed, free up its memory
del vae
free_memory()
设置 4 位量化 (BitsAndBytesConfig
)
本节演示了基础模型的 QLoRA 配置
# Determine compute dtype based on mixed precision
bnb_4bit_compute_dtype = torch.float32
if args.mixed_precision == "fp16":
bnb_4bit_compute_dtype = torch.float16
elif args.mixed_precision == "bf16":
bnb_4bit_compute_dtype = torch.bfloat16
nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
)
transformer = FluxTransformer2DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
quantization_config=nf4_config,
torch_dtype=bnb_4bit_compute_dtype,
)
# Prepare model for k-bit training
transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)
# Gradient checkpointing is enabled later via transformer.enable_gradient_checkpointing() if arg is set
定义 LoRA 配置 (LoraConfig
): 适配器被添加到量化后的 Transformer 中
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"], # FLUX attention blocks
)
transformer.add_adapter(transformer_lora_config)
print(f"trainable params: {transformer.num_parameters(only_trainable=True)} || all params: {transformer.num_parameters()}")
# trainable params: 4,669,440 || all params: 11,906,077,760
只有这些 LoRA 参数是可训练的。
预计算文本嵌入 (CLIP/T5)
在启动 QLoRA 微调之前,我们可以通过一次缓存文本编码器的输出来节省大量的显存和挂钟时间。
在训练时,数据加载器只需读取缓存的嵌入,而无需重新编码字幕,因此 CLIP/T5 编码器无需占用 GPU 内存。
代码
# https://github.com/huggingface/diffusers/blob/main/examples/research_projects/flux_lora_quantization/compute_embeddings.py
import argparse
import pandas as pd
import torch
from datasets import load_dataset
from huggingface_hub.utils import insecure_hashlib
from tqdm.auto import tqdm
from transformers import T5EncoderModel
from diffusers import FluxPipeline
MAX_SEQ_LENGTH = 77
OUTPUT_PATH = "embeddings.parquet"
def generate_image_hash(image):
return insecure_hashlib.sha256(image.tobytes()).hexdigest()
def load_flux_dev_pipeline():
id = "black-forest-labs/FLUX.1-dev"
text_encoder = T5EncoderModel.from_pretrained(id, subfolder="text_encoder_2", load_in_8bit=True, device_map="auto")
pipeline = FluxPipeline.from_pretrained(
id, text_encoder_2=text_encoder, transformer=None, vae=None, device_map="balanced"
)
return pipeline
@torch.no_grad()
def compute_embeddings(pipeline, prompts, max_sequence_length):
all_prompt_embeds = []
all_pooled_prompt_embeds = []
all_text_ids = []
for prompt in tqdm(prompts, desc="Encoding prompts."):
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, max_sequence_length=max_sequence_length)
all_prompt_embeds.append(prompt_embeds)
all_pooled_prompt_embeds.append(pooled_prompt_embeds)
all_text_ids.append(text_ids)
max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
print(f"Max memory allocated: {max_memory:.3f} GB")
return all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids
def run(args):
dataset = load_dataset("Norod78/Yarn-art-style", split="train")
image_prompts = {generate_image_hash(sample["image"]): sample["text"] for sample in dataset}
all_prompts = list(image_prompts.values())
print(f"{len(all_prompts)=}")
pipeline = load_flux_dev_pipeline()
all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids = compute_embeddings(
pipeline, all_prompts, args.max_sequence_length
)
data = []
for i, (image_hash, _) in enumerate(image_prompts.items()):
data.append((image_hash, all_prompt_embeds[i], all_pooled_prompt_embeds[i], all_text_ids[i]))
print(f"{len(data)=}")
# Create a DataFrame
embedding_cols = ["prompt_embeds", "pooled_prompt_embeds", "text_ids"]
df = pd.DataFrame(data, columns=["image_hash"] + embedding_cols)
print(f"{len(df)=}")
# Convert embedding lists to arrays (for proper storage in parquet)
for col in embedding_cols:
df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist())
# Save the dataframe to a parquet file
df.to_parquet(args.output_path)
print(f"Data successfully serialized to {args.output_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--max_sequence_length",
type=int,
default=MAX_SEQ_LENGTH,
help="Maximum sequence length to use for computing the embeddings. The more the higher computational costs.",
)
parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.")
args = parser.parse_args()
run(args)
如何使用
python compute_embeddings.py \
--max_sequence_length 77 \
--output_path embeddings_alphonse_mucha.parquet
通过将其与缓存的 VAE 潜变量 (--cache_latents
) 结合使用,您可以将活动模型缩减为仅包含量化后的 Transformer + LoRA 适配器,从而使整个微调过程在 10 GB 的 GPU 内存下舒适运行。
设置与结果
为了本次演示,我们利用了 NVIDIA RTX 4090 (24GB 显存) 来探索其性能。使用 accelerate
的完整训练命令如下所示。
# You need to pre-compute the text embeddings first. See the diffusers repo.
# https://github.com/huggingface/diffusers/tree/main/examples/research_projects/flux_lora_quantization
accelerate launch --config_file=accelerate.yaml \
train_dreambooth_lora_flux_miniature.py \
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
--data_df_path="embeddings_alphonse_mucha.parquet" \
--output_dir="alphonse_mucha_lora_flux_nf4" \
--mixed_precision="bf16" \
--use_8bit_adam \
--weighting_scheme="none" \
--width=512 \
--height=768 \
--train_batch_size=1 \
--repeats=1 \
--learning_rate=1e-4 \
--guidance_scale=1 \
--report_to="wandb" \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \ # can drop checkpointing when HW has more than 16 GB.
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--cache_latents \
--rank=4 \
--max_train_steps=700 \
--seed="0"
RTX 4090 配置: 在我们的 RTX 4090 上,我们使用了 train_batch_size
为 1,gradient_accumulation_steps
为 4,mixed_precision="bf16"
,gradient_checkpointing=True
,use_8bit_adam=True
,LoRA rank
为 4,分辨率为 512x768。潜变量通过 cache_latents=True
进行缓存。
内存占用 (RTX 4090)
- QLoRA: QLoRA 微调的峰值显存使用量约为 9GB。
- BF16 LoRA: 在相同设置下运行标准 LoRA(基础 FLUX.1-dev 为 FP16)消耗 26 GB 显存。
- BF16 全量微调: 在不进行内存优化的前提下,估计大约需要 120 GB 显存。
训练时间 (RTX 4090): 在 RTX 4090 上,使用 train_batch_size
为 1,分辨率为 512x768,对阿尔丰斯·穆夏数据集进行 700 步的微调大约需要 41 分钟。
输出质量: 最终的衡量标准是生成的艺术作品。以下是我们在 derekl35/alphonse-mucha-style 数据集上使用 QLoRA 微调模型的样本
此表比较了主要的 bf16
精度结果。微调的目标是让模型学习阿尔丰斯·穆夏独特的风格。
提示 | 基础模型输出 | QLoRA 微调输出(穆夏风格) |
---|---|---|
“宁静的黑发女人,月光下的百合,旋涡状的植物图案,阿尔丰斯·穆夏风格” | ![]() |
![]() |
“池塘里的小狗,阿尔丰斯·穆夏风格” | ![]() |
![]() |
“华丽的狐狸,戴着秋叶和浆果的项圈,置身于森林树叶的挂毯之中,阿尔丰斯·穆夏风格” | ![]() |
![]() |
微调后的模型很好地捕捉了穆夏标志性的新艺术风格,这从装饰图案和独特的调色板中显而易见。QLoRA 过程在学习新风格的同时保持了出色的保真度。
使用 TorchAO 进行 FP8 微调
对于拥有计算能力 8.9 或更高(例如 H100、RTX 4090)的 NVIDIA GPU 用户,可以通过 torchao
库利用 FP8 训练实现更高的速度效率。
我们使用略微修改的 diffusers-torchao
训练脚本,在 H100 SXM GPU 上对 FLUX.1-dev LoRA 进行了微调。使用的命令如下
accelerate launch train_dreambooth_lora_flux.py \
--pretrained_model_name_or_path=black-forest-labs/FLUX.1-dev \
--dataset_name=derekl35/alphonse-mucha-style --instance_prompt="a woman, alphonse mucha style" --caption_column="text" \
--output_dir=alphonse_mucha_fp8_lora_flux \
--mixed_precision=bf16 --use_8bit_adam \
--weighting_scheme=none \
--height=768 --width=512 --train_batch_size=1 --repeats=1 \
--learning_rate=1e-4 --guidance_scale=1 --report_to=wandb \
--gradient_accumulation_steps=1 --gradient_checkpointing \
--lr_scheduler=constant --lr_warmup_steps=0 --rank=4 \
--max_train_steps=700 --checkpointing_steps=600 --seed=0 \
--do_fp8_training --push_to_hub
训练运行时,**峰值内存使用量为 36.57 GB**,并在大约 **20 分钟**内完成。
启用 FP8 训练与 torchao
的关键步骤包括
- 使用
torchao.float8
中的convert_to_float8_training
将 FP8 层**注入**模型。 - **定义
module_filter_fn
** 来指定哪些模块应该转换为 FP8,哪些不应该。
如需更详细的指南和代码片段,请参阅此要点和diffusers-torchao
存储库。
使用训练好的 LoRA 适配器进行推理
训练完 LoRA 适配器后,您有两种主要的推理方法。
选项 1:加载 LoRA 适配器
一种方法是在基础模型之上加载您训练好的 LoRA 适配器。
加载 LoRA 的好处
- 灵活性: 无需重新加载基础模型即可轻松切换不同的 LoRA 适配器
- 实验: 通过交换适配器来测试多种艺术风格或概念
- 模块化: 使用
set_adapters()
组合多个 LoRA 适配器以实现创意融合 - 存储效率: 维护一个基础模型和多个小型适配器文件
代码
from diffusers import FluxPipeline, FluxTransformer2DModel, BitsAndBytesConfig
import torch
ckpt_id = "black-forest-labs/FLUX.1-dev"
pipeline = FluxPipeline.from_pretrained(
ckpt_id, torch_dtype=torch.float16
)
pipeline.load_lora_weights("derekl35/alphonse_mucha_qlora_flux", weight_name="pytorch_lora_weights.safetensors")
pipeline.enable_model_cpu_offload()
image = pipeline(
"a puppy in a pond, alphonse mucha style", num_inference_steps=28, guidance_scale=3.5, height=768, width=512, generator=torch.manual_seed(0)
).images[0]
image.save("alphonse_mucha.png")
选项 2:将 LoRA 合并到基础模型中
当您想要以单一风格实现最大效率时,可以将LoRA 权重合并到基础模型中。
合并 LoRA 的好处
- 显存效率: 推理过程中没有适配器权重的额外内存开销
- 速度: 推理速度略快,因为无需执行适配器计算
- 量化兼容性: 可以对合并后的模型重新量化,以实现最大内存效率
代码
from diffusers import FluxPipeline, AutoPipelineForText2Image, FluxTransformer2DModel, BitsAndBytesConfig
import torch
ckpt_id = "black-forest-labs/FLUX.1-dev"
pipeline = FluxPipeline.from_pretrained(
ckpt_id, text_encoder=None, text_encoder_2=None, torch_dtype=torch.float16
)
pipeline.load_lora_weights("derekl35/alphonse_mucha_qlora_flux", weight_name="pytorch_lora_weights.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline.transformer.save_pretrained("fused_transformer")
bnb_4bit_compute_dtype = torch.bfloat16
nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
)
transformer = FluxTransformer2DModel.from_pretrained(
"fused_transformer",
quantization_config=nf4_config,
torch_dtype=bnb_4bit_compute_dtype,
)
pipeline = AutoPipelineForText2Image.from_pretrained(
ckpt_id, transformer=transformer, torch_dtype=bnb_4bit_compute_dtype
)
pipeline.enable_model_cpu_offload()
image = pipeline(
"a puppy in a pond, alphonse mucha style", num_inference_steps=28, guidance_scale=3.5, height=768, width=512, generator=torch.manual_seed(0)
).images[0]
image.save("alphonse_mucha_merged.png")
在 Google Colab 上运行
虽然我们展示了 RTX 4090 上的结果,但相同的代码可以在更易于访问的硬件上运行,例如 Google Colab 中免费提供的 T4 GPU。
在 T4 上,相同的步数下,微调过程预计会显著延长,大约需要 4 小时。这是为了可访问性而做出的权衡,但它使得无需高端硬件即可进行自定义微调成为可能。如果在 Colab 上运行,请注意使用限制,因为 4 小时的训练运行可能会超出限制。
结论
QLoRA 与 diffusers
库相结合,极大地普及了定制 FLUX.1-dev 等最先进模型的能力。正如在 RTX 4090 上所演示的,高效微调唾手可得,并能产生高质量的风格适应。此外,对于拥有最新 NVIDIA 硬件的用户,torchao
通过 FP8 精度实现了更快的训练。
在 Hub 上分享您的创作!
分享您微调的 LoRA 适配器是为开源社区做出贡献的绝佳方式。它让其他人可以轻松尝试您的风格,在您的工作基础上继续发展,并有助于创建充满活力的创意 AI 工具生态系统。
如果您已经训练了 FLUX.1-dev 的 LoRA,我们鼓励您分享它。最简单的方法是将 --push_to_hub 标志添加到训练脚本中。另外,如果您已经训练了一个模型并希望上传它,您可以使用以下代码片段。
# Prereqs:
# - pip install huggingface_hub diffusers
# - Run `huggingface-cli login` (or set HF_TOKEN env-var) once.
# - save model
from huggingface_hub import create_repo, upload_folder
repo_id = "your-username/alphonse_mucha_qlora_flux"
create_repo(repo_id, exist_ok=True)
upload_folder(
repo_id=repo_id,
folder_path="alphonse_mucha_qlora_flux",
commit_message="Add Alphonse Mucha LoRA adapter"
)
查看我们的穆夏 LoRA 和 TorchAO FP8 LoRA。您可以在此集合中找到这两者以及其他适配器。
我们迫不及待地想看到您的创作!