Diffusers 文档

自定义扩散 (Custom Diffusion)

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

自定义扩散 (Custom Diffusion)

Custom Diffusion 是一种用于个性化图像生成模型的训练技术。与 Textual Inversion、DreamBooth 和 LoRA 类似,Custom Diffusion 只需少量(约 4-5 个)示例图像。这项技术通过仅训练交叉注意力层中的权重来实现,并使用一个特殊单词来表示新学习的概念。Custom Diffusion 的独特之处在于它还可以同时学习多个概念。

如果您在 vRAM 有限的 GPU 上进行训练,您应该尝试使用 --enable_xformers_memory_efficient_attention 启用 xFormers,以实现更快、vRAM 需求更低(16GB)的训练。为了节省更多内存,在训练参数中添加 --set_grads_to_none 将梯度设置为 None 而非零(此选项可能会导致一些问题,如果您遇到任何问题,请尝试删除此参数)。

本指南将探讨 train_custom_diffusion.py 脚本,帮助您更熟悉它,以及如何将其应用于您自己的用例。

在运行脚本之前,请确保从源代码安装库

git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .

导航到包含训练脚本的示例文件夹并安装所需的依赖项

cd examples/custom_diffusion
pip install -r requirements.txt
pip install clip-retrieval

🤗 Accelerate 是一个帮助您在多个 GPU/TPU 上或使用混合精度进行训练的库。它将根据您的硬件和环境自动配置您的训练设置。请查看 🤗 Accelerate 快速入门 以了解更多信息。

初始化 🤗 Accelerate 环境

accelerate config

要设置默认的 🤗 Accelerate 环境而不选择任何配置

accelerate config default

或者如果您的环境不支持交互式 shell(例如笔记本),您可以使用

from accelerate.utils import write_basic_config

write_basic_config()

最后,如果您想在自己的数据集上训练模型,请查看 创建训练数据集 指南,了解如何创建与训练脚本兼容的数据集。

以下部分重点介绍了训练脚本中对于理解如何修改它很重要的部分,但并未详细涵盖脚本的各个方面。如果您有兴趣了解更多信息,请随时通读 脚本,并告诉我们您是否有任何问题或疑虑。

脚本参数

训练脚本包含所有参数,可帮助您自定义训练运行。这些参数在 parse_args() 函数中找到。该函数带有默认值,但您也可以在训练命令中设置自己的值。

例如,更改输入图像的分辨率

accelerate launch train_custom_diffusion.py \
  --resolution=256

许多基本参数已在 DreamBooth 训练指南中介绍,因此本指南重点介绍 Custom Diffusion 特有的参数。

  • --freeze_model:冻结交叉注意力层中的键和值参数;默认为 crossattn_kv,但您可以将其设置为 crossattn 以训练交叉注意力层中的所有参数。
  • --concepts_list:要学习多个概念,请提供一个包含概念的 JSON 文件的路径。
  • --modifier_token:用于表示学习到的概念的特殊单词。
  • --initializer_token:用于初始化 modifier_token 嵌入的特殊单词。

先验保留损失 (Prior preservation loss)

先验保留损失是一种利用模型自身生成的样本来帮助它学习如何生成更多样化图像的方法。由于这些生成的样本图像与您提供的图像属于同一类别,它们有助于模型保留其已学习到的关于该类别的信息,以及如何利用这些信息来创建新的组合。

先验保留损失的许多参数已在 DreamBooth 训练指南中介绍。

正则化 (Regularization)

Custom Diffusion 包括使用一小组真实图像训练目标图像,以防止过拟合。正如您所想象的,当您只训练少量图像时,这很容易发生!使用 clip_retrieval 下载 200 张真实图像。class_prompt 应与目标图像属于同一类别。这些图像存储在 class_data_dir 中。

python retrieve.py --class_prompt cat --class_data_dir real_reg/samples_cat --num_class_images 200

要启用正则化,请添加以下参数

  • --with_prior_preservation:是否使用先验保留损失。
  • --prior_loss_weight:控制先验保留损失对模型的影响。
  • --real_prior:是否使用一小组真实图像来防止过拟合。
accelerate launch train_custom_diffusion.py \
  --with_prior_preservation \
  --prior_loss_weight=1.0 \
  --class_data_dir="./real_reg/samples_cat" \
  --class_prompt="cat" \
  --real_prior=True \

训练脚本

Custom Diffusion 训练脚本中的许多代码与 DreamBooth 脚本相似。本指南将重点介绍与 Custom Diffusion 相关的代码。

Custom Diffusion 训练脚本有两个数据集类

接下来,modifier_token添加到分词器中,转换为 token id,并且 token 嵌入的大小被调整以适应新的 modifier_token。然后,modifier_token 嵌入用 initializer_token 的嵌入进行初始化。文本编码器中的所有参数都被冻结,除了 token 嵌入,因为这是模型试图学习与概念关联的内容。

params_to_freeze = itertools.chain(
    text_encoder.text_model.encoder.parameters(),
    text_encoder.text_model.final_layer_norm.parameters(),
    text_encoder.text_model.embeddings.position_embedding.parameters(),
)
freeze_params(params_to_freeze)

现在您需要将 Custom Diffusion 权重添加到注意力层。这是正确设置注意力权重形状和大小,以及在每个 UNet 块中设置适当数量的注意力处理器非常重要的一步。

st = unet.state_dict()
for name, _ in unet.attn_processors.items():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    layer_name = name.split(".processor")[0]
    weights = {
        "to_k_custom_diffusion.weight": st[layer_name + ".to_k.weight"],
        "to_v_custom_diffusion.weight": st[layer_name + ".to_v.weight"],
    }
    if train_q_out:
        weights["to_q_custom_diffusion.weight"] = st[layer_name + ".to_q.weight"]
        weights["to_out_custom_diffusion.0.weight"] = st[layer_name + ".to_out.0.weight"]
        weights["to_out_custom_diffusion.0.bias"] = st[layer_name + ".to_out.0.bias"]
    if cross_attention_dim is not None:
        custom_diffusion_attn_procs[name] = attention_class(
            train_kv=train_kv,
            train_q_out=train_q_out,
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
        ).to(unet.device)
        custom_diffusion_attn_procs[name].load_state_dict(weights)
    else:
        custom_diffusion_attn_procs[name] = attention_class(
            train_kv=False,
            train_q_out=False,
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
        )
del st
unet.set_attn_processor(custom_diffusion_attn_procs)
custom_diffusion_layers = AttnProcsLayers(unet.attn_processors)

优化器被初始化,用于更新交叉注意力层参数。

optimizer = optimizer_class(
    itertools.chain(text_encoder.get_input_embeddings().parameters(), custom_diffusion_layers.parameters())
    if args.modifier_token is not None
    else custom_diffusion_layers.parameters(),
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)

训练循环中,重要的是只更新您正在学习的概念的嵌入。这意味着将所有其他 token 嵌入的梯度设置为零。

if args.modifier_token is not None:
    if accelerator.num_processes > 1:
        grads_text_encoder = text_encoder.module.get_input_embeddings().weight.grad
    else:
        grads_text_encoder = text_encoder.get_input_embeddings().weight.grad
    index_grads_to_zero = torch.arange(len(tokenizer)) != modifier_token_id[0]
    for i in range(len(modifier_token_id[1:])):
        index_grads_to_zero = index_grads_to_zero & (
            torch.arange(len(tokenizer)) != modifier_token_id[i]
        )
    grads_text_encoder.data[index_grads_to_zero, :] = grads_text_encoder.data[
        index_grads_to_zero, :
    ].fill_(0)

启动脚本

完成所有更改或对默认配置满意后,您就可以启动训练脚本了!🚀

在本指南中,您将下载并使用这些示例猫图像。您也可以根据需要创建和使用自己的数据集(请参阅创建训练数据集指南)。

将环境变量 MODEL_NAME 设置为 Hub 上的模型 ID 或本地模型的路径,将 INSTANCE_DIR 设置为您刚刚下载猫图像的路径,并将 OUTPUT_DIR 设置为您要保存模型的路径。您将使用 <new1> 作为特殊词,将新学习的嵌入与其关联。脚本会创建并保存模型检查点和一个 pytorch_custom_diffusion_weights.bin 文件到您的存储库中。

要使用 Weights and Biases 监控训练进度,请在训练命令中添加 --report_to=wandb 参数,并使用 --validation_prompt 指定验证提示词。这对于调试和保存中间结果很有用。

如果您正在训练人脸,Custom Diffusion 团队发现以下参数效果良好

  • --learning_rate=5e-6
  • --max_train_steps 可以在 1000 到 2000 之间
  • --freeze_model=crossattn
  • 至少使用 15-20 张图像进行训练
单个概念
多个概念
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export OUTPUT_DIR="path-to-save-model"
export INSTANCE_DIR="./data/cat"

accelerate launch train_custom_diffusion.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --class_data_dir=./real_reg/samples_cat/ \
  --with_prior_preservation \
  --real_prior \
  --prior_loss_weight=1.0 \
  --class_prompt="cat" \
  --num_class_images=200 \
  --instance_prompt="photo of a <new1> cat"  \
  --resolution=512  \
  --train_batch_size=2  \
  --learning_rate=1e-5  \
  --lr_warmup_steps=0 \
  --max_train_steps=250 \
  --scale_lr \
  --hflip  \
  --modifier_token "<new1>" \
  --validation_prompt="<new1> cat sitting in a bucket" \
  --report_to="wandb" \
  --push_to_hub

训练完成后,您可以使用新的 Custom Diffusion 模型进行推理。

单个概念
多个概念
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16,
).to("cuda")
pipeline.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
pipeline.load_textual_inversion("path-to-save-model", weight_name="<new1>.bin")

image = pipeline(
    "<new1> cat sitting in a bucket",
    num_inference_steps=100,
    guidance_scale=6.0,
    eta=1.0,
).images[0]
image.save("cat.png")

下一步

恭喜您使用 Custom Diffusion 训练了一个模型!🎉 要了解更多信息

< > 在 GitHub 上更新