Diffusers 文档

文本反演

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

文本反演

文本反演 是一种训练技术,仅需少量您想要模型学习内容的示例图像,即可个性化图像生成模型。此技术的工作原理是学习和更新文本嵌入(新的嵌入与您必须在提示中使用的特殊词语相关联),以匹配您提供的示例图像。

如果您在 vRAM 有限的 GPU 上进行训练,则应尝试在训练命令中启用 gradient_checkpointingmixed_precision 参数。您还可以通过使用带有 xFormers 的内存高效注意力来减少内存占用。JAX/Flax 训练也支持在 TPU 和 GPU 上进行高效训练,但它不支持梯度检查点或 xFormers。使用与 PyTorch 相同的配置和设置,Flax 训练脚本应该至少快约 70%!

本指南将探讨 textual_inversion.py 脚本,以帮助您更熟悉它,以及如何针对您自己的用例进行调整。

在运行脚本之前,请确保从源码安装库

git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .

导航到包含训练脚本的示例文件夹,并为您正在使用的脚本安装所需的依赖项

PyTorch
Flax
cd examples/textual_inversion
pip install -r requirements.txt

🤗 Accelerate 是一个库,可帮助您在多个 GPU/TPU 上或使用混合精度进行训练。它将根据您的硬件和环境自动配置您的训练设置。请查看 🤗 Accelerate 快速入门 以了解更多信息。

初始化 🤗 Accelerate 环境

accelerate config

要设置默认的 🤗 Accelerate 环境而无需选择任何配置

accelerate config default

或者,如果您的环境不支持交互式 shell,例如 notebook,您可以使用

from accelerate.utils import write_basic_config

write_basic_config()

最后,如果您想在自己的数据集上训练模型,请查看 创建用于训练的数据集 指南,了解如何创建与训练脚本配合使用的数据集。

以下部分重点介绍了训练脚本中对于理解如何修改它很重要的部分,但它并未详细介绍脚本的每个方面。如果您有兴趣了解更多信息,请随时通读 script,如果您有任何问题或疑虑,请告诉我们。

脚本参数

训练脚本有许多参数,可帮助您根据自己的需求定制训练运行。所有参数及其描述都列在 parse_args() 函数中。在适用的情况下,Diffusers 为每个参数提供默认值,例如训练批次大小和学习率,但如果您愿意,可以在训练命令中随意更改这些值。

例如,要将梯度累积步数增加到高于默认值 1

accelerate launch textual_inversion.py \
  --gradient_accumulation_steps=4

其他一些基本且重要的参数包括:

  • --pretrained_model_name_or_path:Hub 上的模型名称或预训练模型的本地路径
  • --train_data_dir:包含训练数据集(示例图像)的文件夹路径
  • --output_dir:保存训练后模型的位置
  • --push_to_hub:是否将训练后的模型推送到 Hub
  • --checkpointing_steps:在模型训练时保存检查点的频率;如果由于某种原因训练中断,这将非常有用,您可以通过在训练命令中添加 --resume_from_checkpoint 从该检查点继续训练
  • --num_vectors:用于学习嵌入的向量数;增加此参数有助于模型更好地学习,但会增加训练成本
  • --placeholder_token:将学习到的嵌入与之关联的特殊词语(您必须在推理的提示中使用该词语)
  • --initializer_token:一个大致描述您尝试训练的对象或风格的单字
  • --learnable_property:您是否正在训练模型学习新的“风格”(例如,梵高的绘画风格)或“对象”(例如,您的狗)

训练脚本

与某些其他训练脚本不同,textual_inversion.py 有一个自定义数据集类 TextualInversionDataset 用于创建数据集。您可以自定义图像大小、占位符令牌、插值方法、是否裁剪图像等。如果您需要更改数据集的创建方式,可以修改 TextualInversionDataset

接下来,您将在 main() 函数中找到数据集预处理代码和训练循环。

该脚本首先加载 tokenizerscheduler 和模型

# Load tokenizer
if args.tokenizer_name:
    tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
elif args.pretrained_model_name_or_path:
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)

特殊的 占位符令牌 接下来被添加到 tokenizer 中,并且嵌入被重新调整以适应新令牌。

然后,脚本 TextualInversionDataset 创建数据集

train_dataset = TextualInversionDataset(
    data_root=args.train_data_dir,
    tokenizer=tokenizer,
    size=args.resolution,
    placeholder_token=(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))),
    repeats=args.repeats,
    learnable_property=args.learnable_property,
    center_crop=args.center_crop,
    set="train",
)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
)

最后,训练循环 处理从预测噪声残差到更新特殊占位符令牌的嵌入权重的所有其他操作。

如果您想了解有关训练循环工作原理的更多信息,请查看 理解 pipelines、models 和 schedulers 教程,其中分解了去噪过程的基本模式。

启动脚本

一旦您完成所有更改或对默认配置感到满意,您就可以启动训练脚本了!🚀

在本指南中,您将下载一些 猫玩具 的图像并将它们存储在一个目录中。但请记住,如果您愿意,您可以创建和使用自己的数据集(请参阅 创建用于训练的数据集 指南)。

from huggingface_hub import snapshot_download

local_dir = "./cat"
snapshot_download(
    "diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes"
)

将环境变量 MODEL_NAME 设置为 Hub 上的模型 ID 或本地模型的路径,并将 DATA_DIR 设置为您刚刚下载猫图像的路径。该脚本会创建以下文件并将其保存到您的存储库中

  • learned_embeds.bin:与您的示例图像对应的学习到的嵌入向量
  • token_identifier.txt:特殊占位符令牌
  • type_of_concept.txt:您正在训练的概念类型(“对象”或“风格”)

在单个 V100 GPU 上,完整的训练运行需要约 1 小时。

在启动脚本之前还有一件事。如果您有兴趣跟进训练过程,您可以定期保存生成的图像,随着训练的进行。将以下参数添加到训练命令中

--validation_prompt="A <cat-toy> train"
--num_validation_images=4
--validation_steps=100
PyTorch
Flax
export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
export DATA_DIR="./cat"

accelerate launch textual_inversion.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --learnable_property="object" \
  --placeholder_token="<cat-toy>" \
  --initializer_token="toy" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=3000 \
  --learning_rate=5.0e-04 \
  --scale_lr \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --output_dir="textual_inversion_cat" \
  --push_to_hub

训练完成后,您可以像这样使用新训练的模型进行推理

PyTorch
Flax
from diffusers import StableDiffusionPipeline
import torch

pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
pipeline.load_textual_inversion("sd-concepts-library/cat-toy")
image = pipeline("A <cat-toy> train", num_inference_steps=50).images[0]
image.save("cat-train.png")

后续步骤

恭喜您训练了自己的文本反演模型!🎉 要了解有关如何使用新模型的更多信息,以下指南可能会有所帮助:

< > 在 GitHub 上更新