TRL 文档

使用奖励反向传播对齐文本到图像扩散模型

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

使用奖励反向传播对齐文本到图像扩散模型

为什么

如果你的奖励函数是可微分的,那么直接从奖励模型反向传播梯度到扩散模型,比使用像 DDPO 这样的策略梯度算法,在样本和计算效率上都显著更高 (25 倍)。AlignProp 执行完整的随时间反向传播,这允许通过奖励反向传播更新去噪的早期步骤。

开始使用 examples/scripts/alignprop.py

alignprop.py 脚本是使用 AlignProp 训练器微调 Stable Diffusion 模型的一个工作示例。此示例显式配置了与 config 对象 (AlignPropConfig) 关联的总体参数的一个小子集。

注意: 建议使用一个 A100 GPU 来运行此程序。对于较低的内存设置,请考虑将 truncated_backprop_rand 设置为 False。使用默认设置,这将使用 K=1 进行截断反向传播。

几乎每个配置参数都有一个默认值。只有一个命令行标志参数是用户启动并运行所需提供的。用户应拥有一个 huggingface 用户访问令牌,该令牌将用于在微调后将模型上传到 HuggingFace Hub。以下 bash 命令用于启动并运行

python alignprop.py --hf_user_access_token <token>

要获取 stable_diffusion_tuning.py 的文档,请运行 python stable_diffusion_tuning.py --help

以下是在配置训练器时需要记住的事项(代码也会为您检查这些事项),通常情况下(超出使用示例脚本的用例)

  • 可配置的随机截断范围 (--alignprop_config.truncated_rand_backprop_minmax=(0,50)) 第一个数字应等于且大于 0,而第二个数字应等于或小于扩散时间步数 (sample_num_steps)
  • 可配置的截断反向传播绝对步长 (--alignprop_config.truncated_backprop_timestep=49) 该数字应小于扩散时间步数 (sample_num_steps),仅当 truncated_backprop_rand 设置为 False 时才重要

设置图像日志记录 Hook 函数

期望该函数接收一个字典,其中包含键

['image', 'prompt', 'prompt_metadata', 'rewards']

imagepromptprompt_metadatarewards 都是批处理的。您可以自由地记录,建议使用 wandbtensorboard

关键术语

  • rewards : 奖励/分数是与生成的图像相关的数值,是引导 RL 过程的关键
  • prompt : Prompt 是用于生成图像的文本
  • prompt_metadata : Prompt 元数据是与 Prompt 关联的元数据。当奖励模型包含 FLAVA 设置时,这种情况不会为空,在这种情况下,问题和地面答案(链接到生成的图像)与生成的图像一起被期望(参见此处:https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
  • image : 由 Stable Diffusion 模型生成的图像

下面给出了使用 wandb 记录采样图像的示例代码。

# for logging these images to wandb

def image_outputs_hook(image_data, global_step, accelerate_logger):
    # For the sake of this example, we only care about the last batch
    # hence we extract the last element of the list
    result = {}
    images, prompts, rewards = [image_data['images'],image_data['prompts'],image_data['rewards']]
    for i, image in enumerate(images):
        pil = Image.fromarray(
            (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
        )
        pil = pil.resize((256, 256))
        result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
    accelerate_logger.log_images(
        result,
        step=global_step,
    )

使用微调后的模型

假设您已完成所有 epoch 并已将模型推送到 Hub,您可以按如下方式使用微调后的模型

from diffusers import StableDiffusionPipeline
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipeline.to("cuda")

pipeline.load_lora_weights('mihirpd/alignprop-trl-aesthetics')

prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
results = pipeline(prompts)

for prompt, image in zip(prompts,results.images):
    image.save(f"dump/{prompt}.png")

致谢

这项工作深受 此处 的 repo 和相关论文 Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki 的使用奖励反向传播对齐文本到图像扩散模型 的影响。

< > 在 GitHub 上更新