Diffusers 文档

自注意力引导

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

自注意力引导

利用自注意力引导改进扩散模型样本质量由Susung Hong等人完成。

论文摘要如下:

去噪扩散模型 (DDMs) 因其出色的生成质量和多样性而受到关注。这一成功主要归因于类条件或文本条件扩散引导方法的使用,例如分类器和无分类器引导。本文提出了一个超越传统引导方法的更全面的视角。从这个广义视角出发,我们引入了新颖的无条件和无训练策略来提升生成图像的质量。作为一个简单的解决方案,模糊引导提高了中间样本对其精细信息和结构的适用性,使扩散模型能够以适度的引导尺度生成更高质量的样本。在此基础上,自注意力引导 (SAG) 利用扩散模型的中间自注意力图来增强其稳定性和效率。具体来说,SAG 在每次迭代中只对扩散模型关注的区域进行对抗性模糊,并据此进行引导。我们的实验结果表明,SAG 改进了包括 ADM、IDDPM、Stable Diffusion 和 DiT 在内的各种扩散模型的性能。此外,SAG 与传统引导方法相结合可以带来进一步的改进。

您可以在项目页面原始代码库找到有关自注意力引导的更多信息,并在演示笔记本中试用。

请务必查看调度器指南,了解如何在调度器速度和质量之间进行权衡,并参阅跨管道复用组件部分,了解如何高效地将相同组件加载到多个管道中。

StableDiffusionSAGPipeline

diffusers.StableDiffusionSAGPipeline

< >

( vae: AutoencoderKL text_encoder: CLIPTextModel tokenizer: CLIPTokenizer unet: UNet2DConditionModel scheduler: KarrasDiffusionSchedulers safety_checker: StableDiffusionSafetyChecker feature_extractor: CLIPImageProcessor image_encoder: typing.Optional[transformers.models.clip.modeling_clip.CLIPVisionModelWithProjection] = None requires_safety_checker: bool = True )

__call__

< >

( prompt: typing.Union[str, typing.List[str]] = None height: typing.Optional[int] = None width: typing.Optional[int] = None num_inference_steps: int = 50 guidance_scale: float = 7.5 sag_scale: float = 0.75 negative_prompt: typing.Union[str, typing.List[str], NoneType] = None num_images_per_prompt: typing.Optional[int] = 1 eta: float = 0.0 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None ip_adapter_image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor], NoneType] = None ip_adapter_image_embeds: typing.Optional[typing.List[torch.Tensor]] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True callback: typing.Optional[typing.Callable[[int, int, torch.Tensor], NoneType]] = None callback_steps: typing.Optional[int] = 1 cross_attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None clip_skip: typing.Optional[int] = None ) StableDiffusionPipelineOutputtuple

参数

  • prompt (strList[str], 可选) — 用于引导图像生成的提示词或提示词列表。如果未定义,则需要传入 prompt_embeds
  • height (int, 可选, 默认为 self.unet.config.sample_size * self.vae_scale_factor) — 生成图像的高度(像素)。
  • width (int, 可选, 默认为 self.unet.config.sample_size * self.vae_scale_factor) — 生成图像的宽度(像素)。
  • num_inference_steps (int, 可选, 默认为 50) — 去噪步数。更多的去噪步数通常能生成更高质量的图像,但推理速度会变慢。
  • guidance_scale (float, 可选, 默认为 7.5) — 较高的引导尺度值鼓励模型生成与文本 prompt 密切相关的图像,但图像质量会降低。当 guidance_scale > 1 时启用引导尺度。
  • sag_scale (float, 可选, 默认为 0.75) — 介于 [0, 1.0] 之间,以获得更好的质量。
  • negative_prompt (strList[str], 可选) — 用于引导图像生成中不包含内容的提示词或提示词列表。如果未定义,则需要传入 negative_prompt_embeds。当不使用引导时(guidance_scale < 1),此参数将被忽略。
  • num_images_per_prompt (int, 可选, 默认为 1) — 每个提示词生成的图像数量。
  • eta (float, 可选, 默认为 0.0) — 对应于 DDIM 论文中的参数 eta (η)。仅适用于 DDIMScheduler,在其他调度器中将被忽略。
  • generator (torch.GeneratorList[torch.Generator], 可选) — 一个 torch.Generator 用于使生成具有确定性。
  • latents (torch.Tensor, 可选) — 从高斯分布中采样的预生成噪声潜在变量,用作图像生成的输入。可用于使用不同提示词调整同一生成。如果未提供,将使用提供的随机 generator 采样生成一个潜在变量张量。
  • prompt_embeds (torch.Tensor, 可选) — 预生成的文本嵌入。可用于轻松调整文本输入(提示词权重)。如果未提供,文本嵌入将从 prompt 输入参数生成。
  • negative_prompt_embeds (torch.Tensor, 可选) — 预生成的负面文本嵌入。可用于轻松调整文本输入(提示词权重)。如果未提供,negative_prompt_embeds 将从 negative_prompt 输入参数生成。
  • ip_adapter_image — (PipelineImageInput, 可选): 与 IP 适配器配合使用的可选图像输入。
  • ip_adapter_image_embeds (List[torch.Tensor], 可选) — IP-Adapter 的预生成图像嵌入。如果未提供,嵌入将从 ip_adapter_image 输入参数计算。
  • output_type (str, 可选, 默认为 "pil") — 生成图像的输出格式。选择 PIL.Imagenp.array
  • return_dict (bool, 可选, 默认为 True) — 是否返回 StableDiffusionPipelineOutput,否则返回一个 tuple,其中第一个元素是生成的图像列表,第二个元素是布尔值列表,指示相应生成的图像是否包含“不适合工作”(nsfw) 内容。
  • callback (Callable, 可选) — 在推理过程中每 callback_steps 步调用的函数。该函数将使用以下参数调用:callback(step: int, timestep: int, latents: torch.Tensor)
  • callback_steps (int, 可选, 默认为 1) — 调用 callback 函数的频率。如果未指定,回调将在每个步骤调用。
  • cross_attention_kwargs (dict, 可选) — 一个 kwargs 字典,如果指定,将作为参数传递给 self.processor 中定义的 AttentionProcessor
  • clip_skip (int, 可选) — 计算提示词嵌入时要跳过的 CLIP 层数。值为 1 表示将使用倒数第二层的输出计算提示词嵌入。

返回

StableDiffusionPipelineOutputtuple

如果 return_dictTrue,则返回 StableDiffusionPipelineOutput,否则返回一个 tuple,其中第一个元素是生成的图像列表,第二个元素是布尔值列表,指示相应生成的图像是否包含“不适合工作”(nsfw) 内容。

用于生成的管道的调用函数。

示例

>>> import torch
>>> from diffusers import StableDiffusionSAGPipeline

>>> pipe = StableDiffusionSAGPipeline.from_pretrained(
...     "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")

>>> prompt = "a photo of an astronaut riding a horse on mars"
>>> image = pipe(prompt, sag_scale=0.75).images[0]

encode_prompt

< >

( prompt device num_images_per_prompt do_classifier_free_guidance negative_prompt = None prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None lora_scale: typing.Optional[float] = None clip_skip: typing.Optional[int] = None )

参数

  • prompt (strList[str], 可选) — 要编码的提示词
  • device — (torch.device): torch 设备
  • num_images_per_prompt (int) — 每个提示词应生成的图片数量。
  • do_classifier_free_guidance (bool) — 是否使用分类器自由引导。
  • negative_prompt (strList[str], 可选) — 不用于引导图像生成的提示词。如果未定义,则必须传入 negative_prompt_embeds。当不使用引导时(即 guidance_scale 小于 1 时),该参数将被忽略。
  • prompt_embeds (torch.Tensor, 可选) — 预生成的文本嵌入。可用于轻松调整文本输入,例如提示词权重。如果未提供,文本嵌入将从 prompt 输入参数生成。
  • negative_prompt_embeds (torch.Tensor, 可选) — 预生成的负面文本嵌入。可用于轻松调整文本输入,例如提示词权重。如果未提供,negative_prompt_embeds 将从 negative_prompt 输入参数生成。
  • lora_scale (float, 可选) — 如果加载了 LoRA 层,则将应用于文本编码器所有 LoRA 层的 LoRA 缩放因子。
  • clip_skip (int, 可选) — 计算提示词嵌入时要跳过的 CLIP 层数。值为 1 表示将使用倒数第二层的输出计算提示词嵌入。

将提示编码为文本编码器隐藏状态。

StableDiffusionOutput

diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput

< >

( images: typing.Union[typing.List[PIL.Image.Image], numpy.ndarray] nsfw_content_detected: typing.Optional[typing.List[bool]] )

参数

  • images (List[PIL.Image.Image]np.ndarray) — 长度为 batch_size 的去噪 PIL 图像列表或形状为 (batch_size, height, width, num_channels) 的 NumPy 数组。
  • nsfw_content_detected (List[bool]) — 指示相应生成的图像是否包含“不安全内容”(nsfw) 的列表,如果无法执行安全检查则为 None

Stable Diffusion 管道的输出类。

< > 在 GitHub 上更新