扩散器文档

管道回调

Hugging Face's logo
加入 Hugging Face 社区

并获取增强的文档体验

开始使用

管道回调

可以使用 callback_on_step_end 参数使用自定义定义的函数修改管道的去噪循环。回调函数在每个步骤结束时执行,并修改管道属性和变量以用于下一步。这对动态调整某些管道属性或修改张量变量非常有用。这种多功能性允许有趣的用例,例如在每个时间步更改提示嵌入、为提示嵌入分配不同的权重以及编辑指导比例。使用回调,您可以在不修改底层代码的情况下实现新功能!

🤗 Diffusers 目前仅支持 callback_on_step_end,但如果您有很酷的用例并且需要具有不同执行点的回调函数,请随时打开 功能请求

本指南将通过一些您可以使用回调实现的功能来演示回调的工作原理。

官方回调

我们提供一个回调列表,您可以将其插入现有管道并修改去噪循环。这是当前的官方回调列表

  • SDCFGCutoffCallback:对所有 SD 1.5 管道(包括文本到图像、图像到图像、修复和 ControlNet)禁用 CFG,并在特定步骤数后停止。
  • SDXLCFGCutoffCallback:对所有 SDXL 管道(包括文本到图像、图像到图像、修复和 ControlNet)禁用 CFG,并在特定步骤数后停止。
  • IPAdapterScaleCutoffCallback:对所有支持 IP-Adapter 的管道禁用 IP Adapter,并在特定步骤数后停止。

如果您想添加一个新的官方回调,请随时打开 功能请求提交 PR

要设置回调,您需要指定回调生效后的去噪步骤数。您可以使用以下两个参数之一来实现:

  • cutoff_step_ratio:表示步骤比率的浮点数。
  • cutoff_step_index:表示确切步骤数的整数。
import torch

from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline
from diffusers.callbacks import SDXLCFGCutoffCallback


callback = SDXLCFGCutoffCallback(cutoff_step_ratio=0.4)
# can also be used with cutoff_step_index
# callback = SDXLCFGCutoffCallback(cutoff_step_ratio=None, cutoff_step_index=10)

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, use_karras_sigmas=True)

prompt = "a sports car at the road, best quality, high quality, high detail, 8k resolution"

generator = torch.Generator(device="cpu").manual_seed(2628670641)

out = pipeline(
    prompt=prompt,
    negative_prompt="",
    guidance_scale=6.5,
    num_inference_steps=25,
    generator=generator,
    callback_on_step_end=callback,
)

out.images[0].save("official_callback.png")
generated image of a sports car at the road
没有 SDXLCFGCutoffCallback
generated image of a sports car at the road with cfg callback
使用 SDXLCFGCutoffCallback

动态无分类器引导

动态无分类器引导 (CFG) 是一种功能,允许您在特定推理步骤后禁用 CFG,这可以帮助您节省计算,而性能损失最小。此回调函数应具有以下参数

  • pipeline(或管道实例)提供对重要属性的访问,例如 num_timestepsguidance_scale。您可以通过更新底层属性来修改这些属性。对于本示例,您将通过设置 pipeline._guidance_scale=0.0 来禁用 CFG。
  • step_indextimestep 告诉您您在去噪循环中的位置。使用 step_index 在达到 num_timesteps 的 40% 后关闭 CFG。
  • callback_kwargs 是一个字典,其中包含您可以在去噪循环期间修改的张量变量。它仅包含在 callback_on_step_end_tensor_inputs 参数中指定的变量,该参数传递给管道的 __call__ 方法。不同的管道可能使用不同的变量集,因此请检查管道的 _callback_tensor_inputs 属性以获取您可以修改的变量列表。一些常见变量包括 latentsprompt_embeds。对于此函数,在设置 guidance_scale=0.0 后更改 prompt_embeds 的批次大小,以使其正常工作。

您的回调函数应如下所示

def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
        # adjust the batch_size of prompt_embeds according to guidance_scale
        if step_index == int(pipeline.num_timesteps * 0.4):
                prompt_embeds = callback_kwargs["prompt_embeds"]
                prompt_embeds = prompt_embeds.chunk(2)[-1]

                # update guidance_scale and prompt_embeds
                pipeline._guidance_scale = 0.0
                callback_kwargs["prompt_embeds"] = prompt_embeds
        return callback_kwargs

现在,您可以将回调函数传递给 callback_on_step_end 参数,并将 prompt_embeds 传递给 callback_on_step_end_tensor_inputs

import torch
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"

generator = torch.Generator(device="cuda").manual_seed(1)
out = pipeline(
    prompt,
    generator=generator,
    callback_on_step_end=callback_dynamic_cfg,
    callback_on_step_end_tensor_inputs=['prompt_embeds']
)

out.images[0].save("out_custom_cfg.png")

中断扩散过程

中断回调适用于 StableDiffusionPipelineStableDiffusionXLPipeline 的文本到图像、图像到图像和修复。

在使用 Diffusers 构建 UI 时,提前停止扩散过程很有用,因为这允许用户在对中间结果不满意时停止生成过程。您可以使用回调将此功能合并到您的管道中。

此回调函数应采用以下参数:pipelineitcallback_kwargs(必须返回)。将管道的 _interrupt 属性设置为 True,以便在特定步骤数后停止扩散过程。您也可以在回调中实现自己的自定义停止逻辑。

在本示例中,即使 num_inference_steps 设置为 50,扩散过程也会在 10 步后停止。

from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipeline.enable_model_cpu_offload()
num_inference_steps = 50

def interrupt_callback(pipeline, i, t, callback_kwargs):
    stop_idx = 10
    if i == stop_idx:
        pipeline._interrupt = True

    return callback_kwargs

pipeline(
    "A photo of a cat",
    num_inference_steps=num_inference_steps,
    callback_on_step_end=interrupt_callback,
)

在每个生成步骤后显示图像

此提示由 asomoza 贡献。

通过在每个步骤后访问和将潜在变量转换为图像,在每个生成步骤后显示图像。潜在空间压缩为 128x128,因此图像也是 128x128,这对于快速预览很有用。

  1. 使用以下函数将 SDXL 潜在变量(4 个通道)转换为 RGB 张量(3 个通道),如 解释 SDXL 潜在空间 博客文章中所述。
def latents_to_rgb(latents):
    weights = (
        (60, -60, 25, -70),
        (60,  -5, 15, -50),
        (60,  10, -5, -35)
    )

    weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device))
    biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device)
    rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze(-1)
    image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
    image_array = image_array.transpose(1, 2, 0)

    return Image.fromarray(image_array)
  1. 创建一个函数来解码并将潜在变量保存为图像。
def decode_tensors(pipe, step, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]

    image = latents_to_rgb(latents)
    image.save(f"{step}.png")

    return callback_kwargs
  1. decode_tensors 函数传递给 callback_on_step_end 参数,以便在每个步骤后解码张量。您还需要指定要在 callback_on_step_end_tensor_inputs 参数中修改的内容,在本例中是潜在变量。
from diffusers import AutoPipelineForText2Image
import torch
from PIL import Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True
).to("cuda")

image = pipeline(
    prompt="A croissant shaped like a cute bear.",
    negative_prompt="Deformed, ugly, bad anatomy",
    callback_on_step_end=decode_tensors,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
步骤 0
步骤 19
步骤 29
步骤 39
步骤 49
< > 更新 on GitHub