TRL 文档
去噪扩散策略优化
并获得增强的文档体验
开始使用
去噪扩散策略优化
目的
之前 | DDPO 微调后 |
---|---|
![]() | ![]() |
![]() | ![]() |
![]() | ![]() |
通过强化学习开始微调 Stable Diffusion
使用强化学习微调 Stable Diffusion 模型的机制大量使用了 HuggingFace 的 `diffusers` 库。之所以这样说,是因为入门需要对 `diffusers` 库的概念,主要是其中的两个——管道(pipelines)和调度器(schedulers),有一定的熟悉。开箱即用(`diffusers` 库)的情况下,没有适合强化学习微调的 `Pipeline` 或 `Scheduler` 实例。需要进行一些调整。
本库提供了一个管道接口,该接口必须实现才能与 `DDPOTrainer` 结合使用,`DDPOTrainer` 是使用强化学习微调 Stable Diffusion 的主要机制。**注意:目前仅支持 StableDiffusion 架构。** 本库提供了一个默认实现,您可以直接使用。假设默认实现足够或者为了快速启动,请参考本指南中的训练示例。
该接口的目的是将管道和调度器融合到一个对象中,从而最大限度地将所有约束集中在一个地方。设计该接口的目的是希望在本文撰写之时,能支持此仓库及其他地方示例之外的管道和调度器。此外,调度器步骤是此管道接口的一个方法,这可能看起来有些多余,因为原始调度器可以通过接口访问,但这是将调度器步骤输出限制为符合当前算法(DDPO)的输出类型的唯一方法。
要更详细地了解该接口及其关联的默认实现,请点击此处
请注意,默认实现包含 LoRA 实现路径和非 LoRA 实现路径。LoRA 标志默认启用,可以通过传递标志来关闭。基于 LORA 的训练速度更快,并且 LORA 相关的模型超参数对模型收敛的影响不像非 LORA 训练那样挑剔。
此外,还期望提供一个奖励函数和一个提示函数。奖励函数用于评估生成的图像,提示函数用于生成用于生成图像的提示。
ddpo.py 示例脚本入门
`ddpo.py` 脚本是使用 `DDPO` 训练器微调 Stable Diffusion 模型的工作示例。此示例明确配置了与配置对象 (`DDPOConfig`) 相关联的整体参数的一小部分。
**注意:** 建议使用一块 A100 GPU 来运行此示例。低于 A100 的显卡将无法运行此示例脚本,即使能通过相对较小的参数运行,结果也可能不尽如人意。
几乎每个配置参数都有一个默认值。用户只需要一个命令行标志参数即可启动和运行。用户需要拥有一个 huggingface 用户访问令牌,该令牌将用于在微调后将模型上传到 HuggingFace Hub。以下是要输入的 bash 命令以启动和运行:
python ddpo.py --hf_user_access_token <token>
要获取 `stable_diffusion_tuning.py` 的文档,请运行 `python stable_diffusion_tuning.py --help`
在配置训练器时(除了使用示例脚本的情况),请记住以下几点(代码也会为您检查这些):
- 可配置的采样批大小 (`--ddpo_config.sample_batch_size=6`) 应大于或等于可配置的训练批大小 (`--ddpo_config.train_batch_size=3`)
- 可配置的采样批量大小(`--ddpo_config.sample_batch_size=6`)必须能够被可配置的训练批量大小(`--ddpo_config.train_batch_size=3`)整除。
- 可配置的采样批次大小 (`--ddpo_config.sample_batch_size=6`) 必须能被可配置的梯度累积步数 (`--ddpo_config.train_gradient_accumulation_steps=1`) 和可配置的加速器进程数同时整除。
设置图像日志钩子函数
期望函数以列表的形式接收一个列表列表:
[[image, prompt, prompt_metadata, rewards, reward_metadata], ...]
并且 `image`、`prompt`、`prompt_metadata`、`rewards`、`reward_metadata` 都是批处理的。列表列表中的最后一个列表表示最后一个样本批次。您可能希望记录这一个。虽然您可以随意记录,但建议使用 `wandb` 或 `tensorboard`。
关键术语
- `rewards`:奖励/分数是与生成的图像相关的数值,是指导强化学习过程的关键。
- `reward_metadata`:奖励元数据是与奖励相关的元数据。可以将其理解为随奖励一起提供的额外信息负载。
- `prompt`:提示是用于生成图像的文本。
- `prompt_metadata`:提示元数据是与提示相关的元数据。当奖励模型包含 `FLAVA` 设置时,即生成的图像(请参见此处:https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)预期附带问题和地面答案(链接到生成的图像),这种情况下的元数据将不为空。
- `image`:由 Stable Diffusion 模型生成的图像
以下是使用 `wandb` 记录采样图像的示例代码。
# for logging these images to wandb
def image_outputs_hook(image_data, global_step, accelerate_logger):
# For the sake of this example, we only care about the last batch
# hence we extract the last element of the list
result = {}
images, prompts, _, rewards, _ = image_data[-1]
for i, image in enumerate(images):
pil = Image.fromarray(
(image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
)
pil = pil.resize((256, 256))
result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
accelerate_logger.log_images(
result,
step=global_step,
)
使用微调模型
假设您已完成所有 epoch 并将模型推送到 hub,您可以按如下方式使用微调模型:
import torch
from trl import DefaultDDPOStableDiffusionPipeline
pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# memory optimization
pipeline.vae.to(device, torch.float16)
pipeline.text_encoder.to(device, torch.float16)
pipeline.unet.to(device, torch.float16)
prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
results = pipeline(prompts)
for prompt, image in zip(prompts,results.images):
image.save(f"{prompt}.png")
鸣谢
这项工作深受此处的仓库以及相关论文《通过强化学习训练扩散模型》(作者:Kevin Black、Michael Janner、Yilan Du、Ilya Kostrikov、Sergey Levine)此处的影响。
DDPOTrainer
类 trl.DDPOTrainer
< 来源 >( config: DDPOConfig reward_function: typing.Callable[[torch.Tensor, tuple[str], tuple[typing.Any]], torch.Tensor] prompt_function: typing.Callable[[], tuple[str, typing.Any]] sd_pipeline: DDPOStableDiffusionPipeline image_samples_hook: typing.Optional[typing.Callable[[typing.Any, typing.Any, typing.Any], typing.Any]] = None )
参数
- **config** (`DDPOConfig`) — `DDPOTrainer` 的配置对象。请查看 `PPOConfig` 文档以获取更多详细信息。
- **reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) — 要使用的奖励函数 —
- **prompt_function** (Callable[[], tuple[str, Any]]) — 用于生成指导模型的提示的函数 —
- **sd_pipeline** (`DDPOStableDiffusionPipeline`) — 用于训练的 Stable Diffusion 管道。 —
- **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) — 用于记录图像的钩子 —
DDPOTrainer 使用深度扩散策略优化来优化扩散模型。请注意,此训练器深受此处工作的影响:https://github.com/kvablack/ddpo-pytorch。目前仅支持基于 Stable Diffusion 的管道
计算损失
< 来源 >( 潜空间向量 时间步 下一潜空间向量 对数概率 优势 嵌入 )
参数
- **latents** (torch.Tensor) — 从扩散模型采样的潜空间向量,形状:[batch_size, num_channels_latents, height, width]
- **timesteps** (torch.Tensor) — 从扩散模型采样的时间步,形状:[batch_size]
- **next_latents** (torch.Tensor) — 从扩散模型采样的下一个潜在变量,形状:[batch_size, num_channels_latents, height, width]
- **log_probs** (torch.Tensor) — 潜在变量的对数概率,形状:[batch_size]
- **advantages** (torch.Tensor) — 潜在变量的优势,形状:[batch_size]
- **embeds** (torch.Tensor) — 提示的嵌入,形状:[2*batch_size 或 batch_size, ...] 注意:“或”是因为如果 train_cfg 为 True,则预期负面提示会与嵌入连接。
计算一批解包样本的损失
创建模型卡片
< 来源 >( model_name: typing.Optional[str] = None dataset_name: typing.Optional[str] = None tags: typing.Union[str, list[str], NoneType] = None )
使用 Trainer
可用的信息创建模型卡片的草稿。
步骤
< 来源 >( epoch: int global_step: int ) → global_step (int)
执行单步训练。
副作用
- 模型权重已更新
- 将统计数据记录到加速器跟踪器。
- 如果 `self.image_samples_callback` 不为空,它将与 `prompt_image_pairs`、`global_step` 和加速器跟踪器一起被调用。
训练模型指定数量的 epoch
DDPOConfig
类 trl.DDPOConfig
< 来源 >( exp_name: str = 'doc-buil' run_name: str = '' seed: int = 0 log_with: typing.Optional[str] = None tracker_kwargs: dict = <factory> accelerator_kwargs: dict = <factory> project_kwargs: dict = <factory> tracker_project_name: str = 'trl' logdir: str = 'logs' num_epochs: int = 100 save_freq: int = 1 num_checkpoint_limit: int = 5 mixed_precision: str = 'fp16' allow_tf32: bool = True resume_from: str = '' sample_num_steps: int = 50 sample_eta: float = 1.0 sample_guidance_scale: float = 5.0 sample_batch_size: int = 1 sample_num_batches_per_epoch: int = 2 train_batch_size: int = 1 train_use_8bit_adam: bool = False train_learning_rate: float = 0.0003 train_adam_beta1: float = 0.9 train_adam_beta2: float = 0.999 train_adam_weight_decay: float = 0.0001 train_adam_epsilon: float = 1e-08 train_gradient_accumulation_steps: int = 1 train_max_grad_norm: float = 1.0 train_num_inner_epochs: int = 1 train_cfg: bool = True train_adv_clip_max: float = 5.0 train_clip_range: float = 0.0001 train_timestep_fraction: float = 1.0 per_prompt_stat_tracking: bool = False per_prompt_stat_tracking_buffer_size: int = 16 per_prompt_stat_tracking_min_count: int = 16 async_reward_computation: bool = False max_workers: int = 2 negative_prompts: str = '' push_to_hub: bool = False )
参数
- **exp_name** (`str`,*可选*,默认为 `os.path.basename(sys.argv[0])[ -- -len(".py")]`):此实验的名称(默认情况下是文件名,不带扩展名)。
- **run_name** (`str`,*可选*,默认为 `""`) — 此运行的名称。
- **seed** (`int`,*可选*,默认为 `0`) — 随机种子。
- **log_with** (`Literal["wandb", "tensorboard"]]` 或 `None`,*可选*,默认为 `None`) — 使用 'wandb' 或 'tensorboard' 记录,请查看 https://huggingface.co/docs/accelerate/usage_guides/tracking 获取更多详细信息。
- **tracker_kwargs** (`Dict`,*可选*,默认为 `{}`) — 跟踪器的关键字参数(例如 wandb_project)。
- **accelerator_kwargs** (`Dict`,*可选*,默认为 `{}`) — 加速器的关键字参数。
- **project_kwargs** (`Dict`,*可选*,默认为 `{}`) — 加速器项目配置的关键字参数(例如 `logging_dir`)。
- **tracker_project_name** (`str`,*可选*,默认为 `"trl"`) — 用于跟踪的项目名称。
- **logdir** (`str`,*可选*,默认为 `"logs"`) — 用于保存检查点的顶级日志目录。
- num_epochs (
int
, optional, defaults to100
) — 训练的 epoch 数量。 - save_freq (
int
, optional, defaults to1
) — 保存模型检查点之间的 epoch 数量。 - num_checkpoint_limit (
int
, optional, defaults to5
) — 在覆盖旧检查点之前保留的检查点数量。 - mixed_precision (
str
, optional, defaults to"fp16"
) — 混合精度训练。 - allow_tf32 (
bool
, optional, defaults toTrue
) — 允许在 Ampere GPU 上使用tf32
。 - resume_from (
str
, optional, defaults to""
) — 从检查点恢复训练。 - sample_num_steps (
int
, optional, defaults to50
) — 采样器推理步数。 - sample_eta (
float
, optional, defaults to1.0
) — DDIM 采样器的 Eta 参数。 - sample_guidance_scale (
float
, optional, defaults to5.0
) — 无分类器指导权重。 - sample_batch_size (
int
, optional, defaults to1
) — 用于采样的批大小(每 GPU)。 - sample_num_batches_per_epoch (
int
, optional, defaults to2
) — 每个 epoch 采样的批次数量。 - train_batch_size (
int
, optional, defaults to1
) — 用于训练的批大小(每 GPU)。 - train_use_8bit_adam (
bool
, optional, defaults toFalse
) — 使用 bitsandbytes 中的 8 位 Adam 优化器。 - train_learning_rate (
float
, optional, defaults to3e-4
) — 学习率。 - train_adam_beta1 (
float
, optional, defaults to0.9
) — Adam beta1。 - train_adam_beta2 (
float
, optional, defaults to0.999
) — Adam beta2。 - train_adam_weight_decay (
float
, optional, defaults to1e-4
) — Adam 权重衰减。 - train_adam_epsilon (
float
, optional, defaults to1e-8
) — Adam epsilon。 - train_gradient_accumulation_steps (
int
, optional, defaults to1
) — 梯度累积步数。 - train_max_grad_norm (
float
, optional, defaults to1.0
) — 梯度裁剪的最大梯度范数。 - train_num_inner_epochs (
int
, optional, defaults to1
) — 每个外部 epoch 的内部 epoch 数量。 - train_cfg (
bool
, optional, defaults toTrue
) — 训练期间是否使用无分类器指导。 - train_adv_clip_max (
float
, optional, defaults to5.0
) — 将优势剪辑到范围。 - train_clip_range (
float
, optional, defaults to1e-4
) — PPO 裁剪范围。 - train_timestep_fraction (
float
, optional, defaults to1.0
) — 训练时间步长的比例。 - per_prompt_stat_tracking (
bool
, optional, defaults toFalse
) — 是否为每个提示单独跟踪统计信息。 - per_prompt_stat_tracking_buffer_size (
int
, optional, defaults to16
) — 为每个提示在缓冲区中存储的奖励值数量。 - per_prompt_stat_tracking_min_count (
int
, optional, defaults to16
) — 在缓冲区中存储的最小奖励值数量。 - async_reward_computation (
bool
, optional, defaults toFalse
) — 是否异步计算奖励。 - max_workers (
int
, optional, defaults to2
) — 用于异步奖励计算的最大工作器数量。 - negative_prompts (
str
, optional, defaults to""
) — 用作负面示例的提示的逗号分隔列表。 - push_to_hub (
bool
, optional, defaults toFalse
) — 是否将最终模型检查点推送到 Hub。
用于 DDPOTrainer 的配置类。
使用 HfArgumentParser
,我们可以将此类别转换为可在命令行上指定的 argparse 参数。