使用奖励反向传播对齐文本到图像扩散模型
原因
如果您的奖励函数是可微的,则直接从奖励模型反向传播梯度到扩散模型比使用 DDPO 等策略梯度算法效率更高(25 倍),并且计算效率更高。AlignProp 执行完整的反向传播,允许通过奖励反向传播更新去噪的早期步骤。
使用 examples/scripts/alignprop.py 中的示例开始
alignprop.py
脚本是使用 AlignProp
训练器微调 Stable Diffusion 模型的工作示例。此示例明确配置了与配置对象 (AlignPropConfig
) 关联的整体参数的一个小子集。
注意:建议使用一个 A100 GPU 来运行此脚本。对于较低的内存设置,请考虑将 truncated_backprop_rand 设置为 False。在默认设置下,这将使用 K=1 进行截断反向传播。
几乎每个配置参数都有默认值。只有一个命令行标志参数是用户启动和运行所需的参数。用户需要有一个 Huggingface 用户访问令牌,该令牌将用于在微调后将模型上传到 HuggingFace 中心。以下 bash 命令用于启动和运行
python alignprop.py --hf_user_access_token <token>
要获取 stable_diffusion_tuning.py
的文档,请运行 python stable_diffusion_tuning.py --help
在配置训练器时(超出使用示例脚本的使用案例),请牢记以下事项(代码也会为您检查):
- 可配置的随机截断范围 (
--alignprop_config.truncated_rand_backprop_minmax=(0,50)
) 中,第一个数字应等于或大于 0,而第二个数字应等于或小于扩散时间步数 (sample_num_steps) - 可配置的截断反向传播绝对步长(
--alignprop_config.truncated_backprop_timestep=49
)的值应小于扩散步数(sample_num_steps),只有在truncated_backprop_rand设置为False时才起作用。
设置图像日志挂钩函数
期望该函数接收一个包含以下键的字典
['image', 'prompt', 'prompt_metadata', 'rewards']
并且image
、prompt
、prompt_metadata
、rewards
都是批处理的。您可以根据需要进行日志记录,建议使用wandb
或tensorboard
。
关键术语
rewards
:奖励/分数是与生成的图像相关的数值,是引导强化学习过程的关键。prompt
:提示是用于生成图像的文本。prompt_metadata
:提示元数据是与提示相关的元数据。当奖励模型包含一个FLAVA
设置时,此元数据将不为空,在这种情况下,预期生成的图像会与问题和地面真实答案相关联(请参见此处:https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)。image
:由Stable Diffusion模型生成的图像。
下面提供了使用wandb
记录采样图像的示例代码。
# for logging these images to wandb
def image_outputs_hook(image_data, global_step, accelerate_logger):
# For the sake of this example, we only care about the last batch
# hence we extract the last element of the list
result = {}
images, prompts, rewards = [image_data['images'],image_data['prompts'],image_data['rewards']]
for i, image in enumerate(images):
pil = Image.fromarray(
(image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
)
pil = pil.resize((256, 256))
result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
accelerate_logger.log_images(
result,
step=global_step,
)
使用微调后的模型
假设您已经完成了所有轮次,并将您的模型推送到Hub,您可以按如下方式使用微调后的模型
from diffusers import StableDiffusionPipeline
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipeline.to("cuda")
pipeline.load_lora_weights('mihirpd/alignprop-trl-aesthetics')
prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
results = pipeline(prompts)
for prompt, image in zip(prompts,results.images):
image.save(f"dump/{prompt}.png")
鸣谢
这项工作深受以下仓库此处以及相关论文通过奖励反向传播对齐文本到图像扩散模型,作者:Mihir Prabhudesai、Anirudh Goyal、Deepak Pathak、Katerina Fragkiadaki的影响。
< > GitHub更新