TRL 文档

去噪扩散策略优化

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

去噪扩散策略优化

目的

之前 DDPO 微调后

通过强化学习开始微调 Stable Diffusion

使用强化学习微调 Stable Diffusion 模型的机制大量使用了 HuggingFace 的 `diffusers` 库。之所以这样说,是因为入门需要对 `diffusers` 库的概念,主要是其中的两个——管道(pipelines)和调度器(schedulers),有一定的熟悉。开箱即用(`diffusers` 库)的情况下,没有适合强化学习微调的 `Pipeline` 或 `Scheduler` 实例。需要进行一些调整。

本库提供了一个管道接口,该接口必须实现才能与 `DDPOTrainer` 结合使用,`DDPOTrainer` 是使用强化学习微调 Stable Diffusion 的主要机制。**注意:目前仅支持 StableDiffusion 架构。** 本库提供了一个默认实现,您可以直接使用。假设默认实现足够或者为了快速启动,请参考本指南中的训练示例。

该接口的目的是将管道和调度器融合到一个对象中,从而最大限度地将所有约束集中在一个地方。设计该接口的目的是希望在本文撰写之时,能支持此仓库及其他地方示例之外的管道和调度器。此外,调度器步骤是此管道接口的一个方法,这可能看起来有些多余,因为原始调度器可以通过接口访问,但这是将调度器步骤输出限制为符合当前算法(DDPO)的输出类型的唯一方法。

要更详细地了解该接口及其关联的默认实现,请点击此处

请注意,默认实现包含 LoRA 实现路径和非 LoRA 实现路径。LoRA 标志默认启用,可以通过传递标志来关闭。基于 LORA 的训练速度更快,并且 LORA 相关的模型超参数对模型收敛的影响不像非 LORA 训练那样挑剔。

此外,还期望提供一个奖励函数和一个提示函数。奖励函数用于评估生成的图像,提示函数用于生成用于生成图像的提示。

ddpo.py 示例脚本入门

`ddpo.py` 脚本是使用 `DDPO` 训练器微调 Stable Diffusion 模型的工作示例。此示例明确配置了与配置对象 (`DDPOConfig`) 相关联的整体参数的一小部分。

**注意:** 建议使用一块 A100 GPU 来运行此示例。低于 A100 的显卡将无法运行此示例脚本,即使能通过相对较小的参数运行,结果也可能不尽如人意。

几乎每个配置参数都有一个默认值。用户只需要一个命令行标志参数即可启动和运行。用户需要拥有一个 huggingface 用户访问令牌,该令牌将用于在微调后将模型上传到 HuggingFace Hub。以下是要输入的 bash 命令以启动和运行:

python ddpo.py --hf_user_access_token <token>

要获取 `stable_diffusion_tuning.py` 的文档,请运行 `python stable_diffusion_tuning.py --help`

在配置训练器时(除了使用示例脚本的情况),请记住以下几点(代码也会为您检查这些):

  • 可配置的采样批大小 (`--ddpo_config.sample_batch_size=6`) 应大于或等于可配置的训练批大小 (`--ddpo_config.train_batch_size=3`)
  • 可配置的采样批量大小(`--ddpo_config.sample_batch_size=6`)必须能够被可配置的训练批量大小(`--ddpo_config.train_batch_size=3`)整除。
  • 可配置的采样批次大小 (`--ddpo_config.sample_batch_size=6`) 必须能被可配置的梯度累积步数 (`--ddpo_config.train_gradient_accumulation_steps=1`) 和可配置的加速器进程数同时整除。

设置图像日志钩子函数

期望函数以列表的形式接收一个列表列表:

[[image, prompt, prompt_metadata, rewards, reward_metadata], ...]

并且 `image`、`prompt`、`prompt_metadata`、`rewards`、`reward_metadata` 都是批处理的。列表列表中的最后一个列表表示最后一个样本批次。您可能希望记录这一个。虽然您可以随意记录,但建议使用 `wandb` 或 `tensorboard`。

关键术语

  • `rewards`:奖励/分数是与生成的图像相关的数值,是指导强化学习过程的关键。
  • `reward_metadata`:奖励元数据是与奖励相关的元数据。可以将其理解为随奖励一起提供的额外信息负载。
  • `prompt`:提示是用于生成图像的文本。
  • `prompt_metadata`:提示元数据是与提示相关的元数据。当奖励模型包含 `FLAVA` 设置时,即生成的图像(请参见此处:https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)预期附带问题和地面答案(链接到生成的图像),这种情况下的元数据将不为空。
  • `image`:由 Stable Diffusion 模型生成的图像

以下是使用 `wandb` 记录采样图像的示例代码。

# for logging these images to wandb

def image_outputs_hook(image_data, global_step, accelerate_logger):
    # For the sake of this example, we only care about the last batch
    # hence we extract the last element of the list
    result = {}
    images, prompts, _, rewards, _ = image_data[-1]
    for i, image in enumerate(images):
        pil = Image.fromarray(
            (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
        )
        pil = pil.resize((256, 256))
        result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
    accelerate_logger.log_images(
        result,
        step=global_step,
    )

使用微调模型

假设您已完成所有 epoch 并将模型推送到 hub,您可以按如下方式使用微调模型:


import torch
from trl import DefaultDDPOStableDiffusionPipeline

pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# memory optimization
pipeline.vae.to(device, torch.float16)
pipeline.text_encoder.to(device, torch.float16)
pipeline.unet.to(device, torch.float16)

prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
results = pipeline(prompts)

for prompt, image in zip(prompts,results.images):
    image.save(f"{prompt}.png")

鸣谢

这项工作深受此处的仓库以及相关论文《通过强化学习训练扩散模型》(作者:Kevin Black、Michael Janner、Yilan Du、Ilya Kostrikov、Sergey Levine)此处的影响。

DDPOTrainer

trl.DDPOTrainer

< >

( config: DDPOConfig reward_function: typing.Callable[[torch.Tensor, tuple[str], tuple[typing.Any]], torch.Tensor] prompt_function: typing.Callable[[], tuple[str, typing.Any]] sd_pipeline: DDPOStableDiffusionPipeline image_samples_hook: typing.Optional[typing.Callable[[typing.Any, typing.Any, typing.Any], typing.Any]] = None )

参数

  • **config** (`DDPOConfig`) — `DDPOTrainer` 的配置对象。请查看 `PPOConfig` 文档以获取更多详细信息。
  • **reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) — 要使用的奖励函数 —
  • **prompt_function** (Callable[[], tuple[str, Any]]) — 用于生成指导模型的提示的函数 —
  • **sd_pipeline** (`DDPOStableDiffusionPipeline`) — 用于训练的 Stable Diffusion 管道。 —
  • **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) — 用于记录图像的钩子 —

DDPOTrainer 使用深度扩散策略优化来优化扩散模型。请注意,此训练器深受此处工作的影响:https://github.com/kvablack/ddpo-pytorch。目前仅支持基于 Stable Diffusion 的管道

计算损失

< >

( 潜空间向量 时间步 下一潜空间向量 对数概率 优势 嵌入 )

参数

  • **latents** (torch.Tensor) — 从扩散模型采样的潜空间向量,形状:[batch_size, num_channels_latents, height, width]
  • **timesteps** (torch.Tensor) — 从扩散模型采样的时间步,形状:[batch_size]
  • **next_latents** (torch.Tensor) — 从扩散模型采样的下一个潜在变量,形状:[batch_size, num_channels_latents, height, width]
  • **log_probs** (torch.Tensor) — 潜在变量的对数概率,形状:[batch_size]
  • **advantages** (torch.Tensor) — 潜在变量的优势,形状:[batch_size]
  • **embeds** (torch.Tensor) — 提示的嵌入,形状:[2*batch_size 或 batch_size, ...] 注意:“或”是因为如果 train_cfg 为 True,则预期负面提示会与嵌入连接。

计算一批解包样本的损失

创建模型卡片

< >

( model_name: typing.Optional[str] = None dataset_name: typing.Optional[str] = None tags: typing.Union[str, list[str], NoneType] = None )

参数

  • **model_name** (`str` 或 `None`,*可选*,默认为 `None`) — 模型名称。
  • **dataset_name** (`str` 或 `None`,*可选*,默认为 `None`) — 用于训练的数据集名称。
  • **tags** (`str`、`list[str]` 或 `None`,*可选*,默认为 `None`) — 与模型卡关联的标签。

使用 Trainer 可用的信息创建模型卡片的草稿。

步骤

< >

( epoch: int global_step: int ) global_step (int)

参数

  • **epoch** (int) — 当前的 epoch。
  • **global_step** (int) — 当前的全局步。

返回

全局步数 (int)

更新后的全局步。

执行单步训练。

副作用

  • 模型权重已更新
  • 将统计数据记录到加速器跟踪器。
  • 如果 `self.image_samples_callback` 不为空,它将与 `prompt_image_pairs`、`global_step` 和加速器跟踪器一起被调用。

训练

< >

( epochs: typing.Optional[int] = None )

训练模型指定数量的 epoch

DDPOConfig

trl.DDPOConfig

< >

( exp_name: str = 'doc-buil' run_name: str = '' seed: int = 0 log_with: typing.Optional[str] = None tracker_kwargs: dict = <factory> accelerator_kwargs: dict = <factory> project_kwargs: dict = <factory> tracker_project_name: str = 'trl' logdir: str = 'logs' num_epochs: int = 100 save_freq: int = 1 num_checkpoint_limit: int = 5 mixed_precision: str = 'fp16' allow_tf32: bool = True resume_from: str = '' sample_num_steps: int = 50 sample_eta: float = 1.0 sample_guidance_scale: float = 5.0 sample_batch_size: int = 1 sample_num_batches_per_epoch: int = 2 train_batch_size: int = 1 train_use_8bit_adam: bool = False train_learning_rate: float = 0.0003 train_adam_beta1: float = 0.9 train_adam_beta2: float = 0.999 train_adam_weight_decay: float = 0.0001 train_adam_epsilon: float = 1e-08 train_gradient_accumulation_steps: int = 1 train_max_grad_norm: float = 1.0 train_num_inner_epochs: int = 1 train_cfg: bool = True train_adv_clip_max: float = 5.0 train_clip_range: float = 0.0001 train_timestep_fraction: float = 1.0 per_prompt_stat_tracking: bool = False per_prompt_stat_tracking_buffer_size: int = 16 per_prompt_stat_tracking_min_count: int = 16 async_reward_computation: bool = False max_workers: int = 2 negative_prompts: str = '' push_to_hub: bool = False )

参数

  • **exp_name** (`str`,*可选*,默认为 `os.path.basename(sys.argv[0])[ -- -len(".py")]`):此实验的名称(默认情况下是文件名,不带扩展名)。
  • **run_name** (`str`,*可选*,默认为 `""`) — 此运行的名称。
  • **seed** (`int`,*可选*,默认为 `0`) — 随机种子。
  • **log_with** (`Literal["wandb", "tensorboard"]]` 或 `None`,*可选*,默认为 `None`) — 使用 'wandb' 或 'tensorboard' 记录,请查看 https://huggingface.co/docs/accelerate/usage_guides/tracking 获取更多详细信息。
  • **tracker_kwargs** (`Dict`,*可选*,默认为 `{}`) — 跟踪器的关键字参数(例如 wandb_project)。
  • **accelerator_kwargs** (`Dict`,*可选*,默认为 `{}`) — 加速器的关键字参数。
  • **project_kwargs** (`Dict`,*可选*,默认为 `{}`) — 加速器项目配置的关键字参数(例如 `logging_dir`)。
  • **tracker_project_name** (`str`,*可选*,默认为 `"trl"`) — 用于跟踪的项目名称。
  • **logdir** (`str`,*可选*,默认为 `"logs"`) — 用于保存检查点的顶级日志目录。
  • num_epochs (int, optional, defaults to 100) — 训练的 epoch 数量。
  • save_freq (int, optional, defaults to 1) — 保存模型检查点之间的 epoch 数量。
  • num_checkpoint_limit (int, optional, defaults to 5) — 在覆盖旧检查点之前保留的检查点数量。
  • mixed_precision (str, optional, defaults to "fp16") — 混合精度训练。
  • allow_tf32 (bool, optional, defaults to True) — 允许在 Ampere GPU 上使用 tf32
  • resume_from (str, optional, defaults to "") — 从检查点恢复训练。
  • sample_num_steps (int, optional, defaults to 50) — 采样器推理步数。
  • sample_eta (float, optional, defaults to 1.0) — DDIM 采样器的 Eta 参数。
  • sample_guidance_scale (float, optional, defaults to 5.0) — 无分类器指导权重。
  • sample_batch_size (int, optional, defaults to 1) — 用于采样的批大小(每 GPU)。
  • sample_num_batches_per_epoch (int, optional, defaults to 2) — 每个 epoch 采样的批次数量。
  • train_batch_size (int, optional, defaults to 1) — 用于训练的批大小(每 GPU)。
  • train_use_8bit_adam (bool, optional, defaults to False) — 使用 bitsandbytes 中的 8 位 Adam 优化器。
  • train_learning_rate (float, optional, defaults to 3e-4) — 学习率。
  • train_adam_beta1 (float, optional, defaults to 0.9) — Adam beta1。
  • train_adam_beta2 (float, optional, defaults to 0.999) — Adam beta2。
  • train_adam_weight_decay (float, optional, defaults to 1e-4) — Adam 权重衰减。
  • train_adam_epsilon (float, optional, defaults to 1e-8) — Adam epsilon。
  • train_gradient_accumulation_steps (int, optional, defaults to 1) — 梯度累积步数。
  • train_max_grad_norm (float, optional, defaults to 1.0) — 梯度裁剪的最大梯度范数。
  • train_num_inner_epochs (int, optional, defaults to 1) — 每个外部 epoch 的内部 epoch 数量。
  • train_cfg (bool, optional, defaults to True) — 训练期间是否使用无分类器指导。
  • train_adv_clip_max (float, optional, defaults to 5.0) — 将优势剪辑到范围。
  • train_clip_range (float, optional, defaults to 1e-4) — PPO 裁剪范围。
  • train_timestep_fraction (float, optional, defaults to 1.0) — 训练时间步长的比例。
  • per_prompt_stat_tracking (bool, optional, defaults to False) — 是否为每个提示单独跟踪统计信息。
  • per_prompt_stat_tracking_buffer_size (int, optional, defaults to 16) — 为每个提示在缓冲区中存储的奖励值数量。
  • per_prompt_stat_tracking_min_count (int, optional, defaults to 16) — 在缓冲区中存储的最小奖励值数量。
  • async_reward_computation (bool, optional, defaults to False) — 是否异步计算奖励。
  • max_workers (int, optional, defaults to 2) — 用于异步奖励计算的最大工作器数量。
  • negative_prompts (str, optional, defaults to "") — 用作负面示例的提示的逗号分隔列表。
  • push_to_hub (bool, optional, defaults to False) — 是否将最终模型检查点推送到 Hub。

用于 DDPOTrainer 的配置类。

使用 HfArgumentParser,我们可以将此类别转换为可在命令行上指定的 argparse 参数。

< > 在 GitHub 上更新