TRL 文档

降噪扩散策略优化

Hugging Face's logo
加入 Hugging Face 社区

并获得增强型文档体验

开始使用

降噪扩散策略优化

原因

之前 DDPO 微调后

使用强化学习开始 Stable Diffusion 微调

使用强化学习微调 Stable Diffusion 模型的机制大量使用 HuggingFace 的 diffusers 库。之所以要说明这一点,是因为入门需要对 diffusers 库的概念有一定的了解,主要是两个概念 - 管道和调度器。在开箱即用时(diffusers 库),没有适合使用强化学习进行微调的 PipelineScheduler 实例。需要进行一些调整。

这个库提供了一个管道接口,需要实现这个接口才能与 DDPOTrainer 一起使用,DDPOTrainer 是使用强化学习微调 Stable Diffusion 的主要机制。注意:目前只支持 StableDiffusion 架构。有一个默认的接口实现,您可以开箱即用。假设默认实现足够用,或者为了让事情顺利进行,请参考本指南附带的训练示例。

该接口的目的是将管道和调度器融合到一个对象中,这样可以最大程度地减少将约束集中在一个地方的需要。该接口的目的是希望在本文档撰写时,能够满足除本存储库中以及其他地方的示例之外的管道和调度器。此外,调度器步骤是该管道接口的一种方法,这在可以通过接口访问原始调度器的情况下似乎是多余的,但这是将调度器步骤输出约束为适合当前算法(DDPO)的输出类型的唯一方法。

要更详细地了解该接口及其相关默认实现,请访问 此处

请注意,默认实现具有 LoRA 实现路径和非 LoRA 基于实现路径。LoRA 标志默认情况下处于启用状态,可以通过传递标志来禁用它。基于 LoRA 的训练速度更快,与模型收敛相关的 LoRA 关联模型超参数并不像非 LoRA 基于的训练那样挑剔。

此外,还期望提供奖励函数和提示函数。奖励函数用于评估生成的图像,提示函数用于生成用于生成图像的提示。

使用示例/脚本/ddpo.py 入门

ddpo.py 脚本是使用 DDPO 训练器微调 Stable Diffusion 模型的工作示例。此示例明确配置了与配置对象 (DDPOConfig) 关联的总体参数的一小部分。

注意:建议使用一台 A100 GPU 来运行。低于 A100 的任何设备都无法运行此示例脚本,即使它通过相对较小的参数运行,结果也很可能很差。

几乎每个配置参数都有一个默认值。用户只需要一个命令行标志参数就可以启动和运行。用户应拥有一个 HuggingFace 用户访问令牌,该令牌将用于在微调后将模型上传到 HuggingFace 集线器。以下 bash 命令用于启动和运行

python ddpo.py --hf_user_access_token <token>

要获取 stable_diffusion_tuning.py 的文档,请运行 python stable_diffusion_tuning.py --help

以下是在配置训练器时(超出使用示例脚本的用例)需要牢记的事项(代码也会为您检查):

  • 可配置的样本批次大小 (--ddpo_config.sample_batch_size=6) 应大于或等于可配置的训练批次大小 (--ddpo_config.train_batch_size=3)
  • 可配置的样本批次大小 (--ddpo_config.sample_batch_size=6) 必须能够被可配置的训练批次大小 (--ddpo_config.train_batch_size=3) 整除。
  • 可配置的样本批次大小 (--ddpo_config.sample_batch_size=6) 必须能够被可配置的梯度累积步数 (--ddpo_config.train_gradient_accumulation_steps=1) 和可配置的加速器进程计数整除。

设置图像日志记录挂钩函数

期望函数得到一个列表列表,形式如下:

[[image, prompt, prompt_metadata, rewards, reward_metadata], ...]

imagepromptprompt_metadatarewardsreward_metadata 都是批处理的。列表列表中的最后一个列表代表最后一个样本批次。您可能希望记录此批次。尽管您可以随意以任何方式记录,但建议使用 wandbtensorboard

关键术语

  • rewards:奖励/分数是与生成的图像相关联的数字,是引导 RL 过程的关键
  • reward_metadata:奖励元数据是与奖励相关联的元数据。可以将其视为与奖励一起提供的额外信息有效负载。
  • prompt:提示是用于生成图像的文本。
  • prompt_metadata:提示元数据是与提示相关联的元数据。当奖励模型包含 FLAVA 设置,并且期望生成的图像带有问题和基础答案(链接到生成的图像)时,此元数据将不为空(参见:https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)。
  • image:Stable Diffusion 模型生成的图像。

以下给出了使用 wandb 记录采样图像的示例代码。

# for logging these images to wandb

def image_outputs_hook(image_data, global_step, accelerate_logger):
    # For the sake of this example, we only care about the last batch
    # hence we extract the last element of the list
    result = {}
    images, prompts, _, rewards, _ = image_data[-1]
    for i, image in enumerate(images):
        pil = Image.fromarray(
            (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
        )
        pil = pil.resize((256, 256))
        result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
    accelerate_logger.log_images(
        result,
        step=global_step,
    )

使用微调后的模型

假设您已经完成了所有 epochs 并将模型推送到集线器,您可以按如下方式使用微调后的模型


import torch
from trl import DefaultDDPOStableDiffusionPipeline

pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# memory optimization
pipeline.vae.to(device, torch.float16)
pipeline.text_encoder.to(device, torch.float16)
pipeline.unet.to(device, torch.float16)

prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
results = pipeline(prompts)

for prompt, image in zip(prompts,results.images):
    image.save(f"{prompt}.png")

致谢

这项工作深受以下仓库的影响:此处以及相关的论文:Kevin Black、Michael Janner、Yilan Du、Ilya Kostrikov 和 Sergey Levine 撰写的《使用强化学习训练扩散模型》.

< > 在 GitHub 上更新