TRL 文档
去噪扩散策略优化
并获取增强的文档体验
开始使用
去噪扩散策略优化
为什么选择 DDPO?
之前 | DDPO 微调之后 |
---|---|
![]() | ![]() |
![]() | ![]() |
![]() | ![]() |
开始使用强化学习微调 Stable Diffusion
使用强化学习微调 Stable Diffusion 模型的机制大量使用了 HuggingFace 的 diffusers
库。 声明这一点的原因是,入门需要对 diffusers
库的概念有一定的熟悉度,主要是管道 (pipelines) 和调度器 (schedulers) 这两个概念。 直接使用 (diffusers
库),没有适用于强化学习微调的 Pipeline
或 Scheduler
实例。 需要进行一些调整。
此库提供了一个管道接口,需要实现该接口才能与 DDPOTrainer
一起使用,DDPOTrainer
是使用强化学习微调 Stable Diffusion 的主要机制。 注意:目前仅支持 StableDiffusion 架构。 您可以直接使用此接口的默认实现。 假设默认实现足够,或者为了快速上手,请参考本指南旁边的训练示例。
该接口的目的是将管道和调度器融合到一个对象中,从而在将约束全部放在一个地方方面实现最小化。 该接口的设计希望能够满足超出本存储库和撰写本文时其他地方的示例的管道和调度器。 此外,调度器步骤是此管道接口的一个方法,考虑到可以通过该接口访问原始调度器,这可能看起来是冗余的,但这是将调度器步骤输出约束为适合手头算法(DDPO)的输出类型的唯一方法。
要更详细地了解该接口和相关的默认实现,请访问 此处
请注意,默认实现具有基于 LoRA 的实现路径和非基于 LoRA 的实现路径。 LoRA 标志默认启用,可以通过传入标志来关闭。 基于 LORA 的训练速度更快,并且与模型收敛相关的 LORA 相关模型超参数不如非 LORA 训练那样挑剔。
此外,还需要提供奖励函数和提示函数。 奖励函数用于评估生成的图像,提示函数用于生成用于生成图像的提示。
开始使用 examples/scripts/ddpo.py
ddpo.py
脚本是使用 DDPO
训练器微调 Stable Diffusion 模型的工作示例。 此示例显式配置了与配置对象 (DDPOConfig
) 关联的总体参数的一小部分子集。
注意: 建议使用一个 A100 GPU 来运行此示例。 任何低于 A100 的 GPU 都将无法运行此示例脚本,即使使用相对较小的参数运行,结果也可能很差。
几乎每个配置参数都有默认值。 只有一个命令行标志参数是用户启动并运行所需的。 用户应拥有一个 huggingface 用户访问令牌,该令牌将用于在微调后将模型上传到 HuggingFace Hub。 要运行,请输入以下 bash 命令
python ddpo.py --hf_user_access_token <token>
要获取 stable_diffusion_tuning.py
的文档,请运行 python stable_diffusion_tuning.py --help
以下是在配置训练器时需要记住的事项(除了使用示例脚本的用例外)(代码也会为您检查这一点)
- 可配置的样本批次大小 (
--ddpo_config.sample_batch_size=6
) 应大于或等于可配置的训练批次大小 (--ddpo_config.train_batch_size=3
) - 可配置的样本批次大小 (
--ddpo_config.sample_batch_size=6
) 必须可被可配置的训练批次大小 (--ddpo_config.train_batch_size=3
) 整除 - 可配置的样本批次大小 (
--ddpo_config.sample_batch_size=6
) 必须可被可配置的梯度累积步数 (--ddpo_config.train_gradient_accumulation_steps=1
) 和可配置的加速器进程计数整除
设置图像日志记录 Hook 函数
期望该函数接收一个列表的列表,形式如下
[[image, prompt, prompt_metadata, rewards, reward_metadata], ...]
image
、prompt
、prompt_metadata
、rewards
、reward_metadata
都是批处理的。 列表的列表中的最后一个列表表示最后一个样本批次。 您可能想要记录这一个。 虽然您可以随意记录,但建议使用 wandb
或 tensorboard
。
关键术语
rewards
: 奖励/分数是与生成的图像相关的数值,是指导 RL 过程的关键reward_metadata
: 奖励元数据是与奖励关联的元数据。 可以将其视为与奖励一起传递的额外信息负载prompt
: 提示是用于生成图像的文本prompt_metadata
: 提示元数据是与提示关联的元数据。 当奖励模型包含FLAVA
设置时,这种情况不会为空,其中问题和地面答案(链接到生成的图像)与生成的图像一起被期望(请参阅此处: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)image
: 由 Stable Diffusion 模型生成的图像
下面给出了使用 wandb
记录采样图像的示例代码。
# for logging these images to wandb
def image_outputs_hook(image_data, global_step, accelerate_logger):
# For the sake of this example, we only care about the last batch
# hence we extract the last element of the list
result = {}
images, prompts, _, rewards, _ = image_data[-1]
for i, image in enumerate(images):
pil = Image.fromarray(
(image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
)
pil = pil.resize((256, 256))
result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
accelerate_logger.log_images(
result,
step=global_step,
)
使用微调后的模型
假设您已完成所有 epoch 并且已将您的模型推送到 Hub,您可以按如下方式使用微调后的模型
import torch
from trl import DefaultDDPOStableDiffusionPipeline
pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# memory optimization
pipeline.vae.to(device, torch.float16)
pipeline.text_encoder.to(device, torch.float16)
pipeline.unet.to(device, torch.float16)
prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
results = pipeline(prompts)
for prompt, image in zip(prompts,results.images):
image.save(f"{prompt}.png")
致谢
这项工作深受 此处 的 repo 以及相关论文 Training Diffusion Models with Reinforcement Learning by Kevin Black, Michael Janner, Yilan Du, Ilya Kostrikov, Sergey Levine 的影响。
DDPOTrainer
class trl.DDPOTrainer
< source >( config: DDPOConfig reward_function: typing.Callable[[torch.Tensor, tuple[str], tuple[typing.Any]], torch.Tensor] prompt_function: typing.Callable[[], tuple[str, typing.Any]] sd_pipeline: DDPOStableDiffusionPipeline image_samples_hook: typing.Optional[typing.Callable[[typing.Any, typing.Any, typing.Any], typing.Any]] = None )
参数
- **config** (
DDPOConfig
) — DDPOTrainer 的配置对象。 有关更多详细信息,请查看PPOConfig
的文档。 - **reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) — 要使用的奖励函数 —
- **prompt_function** (Callable[[], tuple[str, Any]]) — 用于生成提示以指导模型的功能 —
- **sd_pipeline** (
DDPOStableDiffusionPipeline
) — 用于训练的 Stable Diffusion 管道。 — - **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) — 用于记录图像的 Hook —
DDPOTrainer 使用深度扩散策略优化来优化扩散模型。请注意,此训练器在很大程度上受到了以下工作的启发:https://github.com/kvablack/ddpo-pytorch。目前仅支持基于 Stable Diffusion 的 pipelines。
calculate_loss
< source >( latents timesteps next_latents log_probs advantages embeds )
参数
- latents (torch.Tensor) — 从扩散模型中采样的潜在变量,形状:[batch_size, num_channels_latents, height, width]
- timesteps (torch.Tensor) — 从扩散模型中采样的时间步,形状:[batch_size]
- next_latents (torch.Tensor) — 从扩散模型中采样的下一个潜在变量,形状:[batch_size, num_channels_latents, height, width]
- log_probs (torch.Tensor) — 潜在变量的对数概率,形状:[batch_size]
- advantages (torch.Tensor) — 潜在变量的优势,形状:[batch_size]
- embeds (torch.Tensor) — 提示词的嵌入,形状:[2*batch_size 或 batch_size, …] 注意:“或”是因为如果 train_cfg 为 True,则期望负面提示词与嵌入连接
计算一个批次解包样本的损失
create_model_card
< source >( model_name: typing.Optional[str] = None dataset_name: typing.Optional[str] = None tags: typing.Union[str, list[str], NoneType] = None )
使用 Trainer
可用的信息创建模型卡的草稿。
step
< source >( epoch: int global_step: int ) → global_step (int)
执行单步训练。
副作用
- 模型权重已更新
- 将统计信息记录到加速器跟踪器。
- 如果
self.image_samples_callback
不是 None,它将使用 prompt_image_pairs、global_step 和加速器跟踪器进行调用。
训练模型给定的 epoch 数
DDPOConfig
class trl.DDPOConfig
< source >( exp_name: str = 'doc-buil' run_name: str = '' seed: int = 0 log_with: typing.Optional[str] = None tracker_kwargs: dict = <factory> accelerator_kwargs: dict = <factory> project_kwargs: dict = <factory> tracker_project_name: str = 'trl' logdir: str = 'logs' num_epochs: int = 100 save_freq: int = 1 num_checkpoint_limit: int = 5 mixed_precision: str = 'fp16' allow_tf32: bool = True resume_from: str = '' sample_num_steps: int = 50 sample_eta: float = 1.0 sample_guidance_scale: float = 5.0 sample_batch_size: int = 1 sample_num_batches_per_epoch: int = 2 train_batch_size: int = 1 train_use_8bit_adam: bool = False train_learning_rate: float = 0.0003 train_adam_beta1: float = 0.9 train_adam_beta2: float = 0.999 train_adam_weight_decay: float = 0.0001 train_adam_epsilon: float = 1e-08 train_gradient_accumulation_steps: int = 1 train_max_grad_norm: float = 1.0 train_num_inner_epochs: int = 1 train_cfg: bool = True train_adv_clip_max: float = 5.0 train_clip_range: float = 0.0001 train_timestep_fraction: float = 1.0 per_prompt_stat_tracking: bool = False per_prompt_stat_tracking_buffer_size: int = 16 per_prompt_stat_tracking_min_count: int = 16 async_reward_computation: bool = False max_workers: int = 2 negative_prompts: str = '' push_to_hub: bool = False )
参数
- exp_name (
str
, 可选, 默认为os.path.basename(sys.argv[0])[ -- -len(".py")]
): 此实验的名称(默认情况下是文件名,不带扩展名)。 - run_name (
str
, 可选, 默认为""
) — 此运行的名称。 - seed (
int
, 可选, 默认为0
) — 随机种子。 - log_with (
Literal["wandb", "tensorboard"]]
或None
, 可选, 默认为None
) — 使用 ‘wandb’ 或 ‘tensorboard’ 进行日志记录,请查看 https://huggingface.co/docs/accelerate/usage_guides/tracking 了解更多详情。 - project_kwargs (
Dict
, 可选, 默认为{}
) — 加速器项目配置的关键字参数 (例如logging_dir
)。 - tracker_project_name (
str
, 可选, 默认为"trl"
) — 用于跟踪的项目名称。 - logdir (
str
, 可选, 默认为"logs"
) — 用于检查点保存的顶层日志目录。 - num_epochs (
int
, 可选, 默认为100
) — 训练的轮数。 - save_freq (
int
, 可选, 默认为1
) — 保存模型检查点之间的轮数。 - num_checkpoint_limit (
int
, 可选, 默认为5
) — 在覆盖旧检查点之前保留的检查点数量。 - mixed_precision (
str
, 可选, 默认为"fp16"
) — 混合精度训练。 - allow_tf32 (
bool
, 可选, 默认为True
) — 允许在 Ampere GPU 上使用tf32
。 - resume_from (
str
, 可选, 默认为""
) — 从检查点继续训练。 - sample_num_steps (
int
, 可选, 默认为50
) — 采样器推理步数。 - sample_eta (
float
, 可选, 默认为1.0
) — DDIM 采样器的 Eta 参数。 - sample_guidance_scale (
float
, 可选, 默认为5.0
) — 无分类器引导权重。 - sample_batch_size (
int
, 可选, 默认为1
) — 用于采样的批大小(每个 GPU)。 - sample_num_batches_per_epoch (
int
, 可选, 默认为2
) — 每个 epoch 采样的批次数。 - train_batch_size (
int
, 可选, 默认为1
) — 用于训练的批大小(每个 GPU)。 - train_use_8bit_adam (
bool
, 可选, 默认为False
) — 使用来自 bitsandbytes 的 8 位 Adam 优化器。 - train_learning_rate (
float
, 可选, 默认为3e-4
) — 学习率。 - train_adam_beta1 (
float
, 可选, 默认为0.9
) — Adam beta1。 - train_adam_beta2 (
float
, 可选, 默认为0.999
) — Adam beta2。 - train_adam_weight_decay (
float
, 可选, 默认为1e-4
) — Adam 权重衰减。 - train_adam_epsilon (
float
, 可选, 默认为1e-8
) — Adam epsilon。 - train_gradient_accumulation_steps (
int
, 可选, 默认为1
) — 梯度累积步数。 - train_max_grad_norm (
float
, 可选, 默认为1.0
) — 用于梯度裁剪的最大梯度范数。 - train_num_inner_epochs (
int
, 可选, 默认为1
) — 每个外部 epoch 的内部 epoch 数。 - train_cfg (
bool
, 可选, 默认为True
) — 训练期间是否使用无分类器引导。 - train_adv_clip_max (
float
, 可选, 默认为5.0
) — 将优势裁剪到该范围。 - train_clip_range (
float
, 可选, 默认为1e-4
) — PPO 裁剪范围。 - train_timestep_fraction (
float
, 可选, 默认为1.0
) — 用于训练的时间步分数。 - per_prompt_stat_tracking (
bool
, 可选, 默认为False
) — 是否分别跟踪每个 prompt 的统计信息。 - per_prompt_stat_tracking_buffer_size (
int
, 可选, 默认为16
) — 在缓冲区中为每个 prompt 存储的奖励值数量。 - per_prompt_stat_tracking_min_count (
int
, 可选, 默认为16
) — 在缓冲区中存储的最小奖励值数量。 - async_reward_computation (
bool
, 可选, 默认为False
) — 是否异步计算奖励。 - max_workers (
int
, 可选, 默认为2
) — 用于异步奖励计算的最大工作进程数。 - negative_prompts (
str
, 可选, 默认为""
) — 用逗号分隔的提示词列表,用作负面示例。 - push_to_hub (
bool
, 可选, 默认为False
) — 是否将最终模型检查点推送到 Hub。
DDPOTrainer 的配置类。
使用 HfArgumentParser
,我们可以将此类转换为可以在命令行中指定的 argparse 参数。