Diffusers 文档
SANA-Sprint
并获得增强的文档体验
开始使用
SANA-Sprint
SANA-Sprint:带连续时间一致性蒸馏的一步扩散,由NVIDIA、MIT HAN Lab和Hugging Face的Junsong Chen、Shuchen Xue、Yuyang Zhao、Jincheng Yu、Sayak Paul、Junyu Chen、Han Cai、Enze Xie、Song Han撰写
论文摘要如下:
本文介绍了SANA-Sprint,一种用于超快速文本到图像(T2I)生成的有效扩散模型。SANA-Sprint建立在预训练基础模型之上,并辅以混合蒸馏,将推理步骤从20步显著减少到1-4步。我们引入了三项关键创新:(1)我们提出了一种免训练方法,将预训练的流匹配模型用于连续时间一致性蒸馏(sCM),消除了从头开始训练的高昂成本,并实现了高训练效率。我们的混合蒸馏策略将sCM与潜在对抗蒸馏(LADD)结合起来:sCM确保与教师模型对齐,而LADD增强了单步生成保真度。(2)SANA-Sprint是一个统一的步长自适应模型,可在1-4步内实现高质量生成,消除了针对特定步长的训练,提高了效率。(3)我们将ControlNet与SANA-Sprint集成,实现实时交互式图像生成,为用户交互提供即时视觉反馈。SANA-Sprint在速度-质量权衡方面建立了新的帕累托前沿,仅用1步就实现了最先进的性能,FID为7.59,GenEval为0.74——优于FLUX-schnell(FID为7.94 / GenEval为0.71),同时速度快10倍(H100上0.1秒 vs 1.1秒)。它还在H100上实现了1024×1024图像的0.1秒(T2I)和0.25秒(ControlNet)延迟,以及在RTX 4090上0.31秒(T2I)的延迟,展示了其卓越的效率和AI驱动的消费级应用(AIPC)的潜力。代码和预训练模型将开源。
此管道由lawrence-cj、shuchen Xue和Enze Xie贡献。原始代码库可在此处找到。原始权重可在hf.co/Efficient-Large-Model下找到。
可用模型
模型 | 推荐数据类型 |
---|---|
Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers | torch.bfloat16 |
Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers | torch.bfloat16 |
更多信息请参考此集合。
注意:推荐的数据类型是指Transformer权重。文本编码器必须保持为torch.bfloat16
,VAE权重必须保持为torch.bfloat16
或torch.float32
,模型才能正常工作。请参阅下面的推理示例,了解如何使用推荐的数据类型加载模型。
量化
量化有助于通过以较低精度数据类型存储模型权重来减少大型模型的内存需求。但是,量化对视频质量的影响可能因视频模型而异。
请参阅量化概述,了解更多支持的量化后端以及如何选择支持您用例的量化后端。下面的示例演示了如何使用bitsandbytes加载量化的SanaSprintPipeline进行推理。
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaSprintPipeline
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = AutoModel.from_pretrained(
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
subfolder="text_encoder",
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = SanaTransformer2DModel.from_pretrained(
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)
pipeline = SanaSprintPipeline.from_pretrained(
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
text_encoder=text_encoder_8bit,
transformer=transformer_8bit,
torch_dtype=torch.bfloat16,
device_map="balanced",
)
prompt = "a tiny astronaut hatching from an egg on the moon"
image = pipeline(prompt).images[0]
image.save("sana.png")
设置max_timesteps
用户可以调整max_timesteps
的值以实验生成输出的视觉质量。默认的max_timesteps
值是通过推理时间搜索过程获得的。有关其更多详细信息,请查看论文。
图像到图像
SanaSprintImg2ImgPipeline是一个用于图像到图像生成的管道。它接收一张输入图像和一个提示,并根据输入图像和提示生成一张新图像。
import torch
from diffusers import SanaSprintImg2ImgPipeline
from diffusers.utils.loading_utils import load_image
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
)
pipe = SanaSprintImg2ImgPipeline.from_pretrained(
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
torch_dtype=torch.bfloat16)
pipe.to("cuda")
image = pipe(
prompt="a cute pink bear",
image=image,
strength=0.5,
height=832,
width=480
).images[0]
image.save("output.png")
SanaSprintPipeline
类 diffusers.SanaSprintPipeline
< 源 >( tokenizer: typing.Union[transformers.models.gemma.tokenization_gemma.GemmaTokenizer, transformers.models.gemma.tokenization_gemma_fast.GemmaTokenizerFast] text_encoder: Gemma2PreTrainedModel vae: AutoencoderDC transformer: SanaTransformer2DModel scheduler: DPMSolverMultistepScheduler )
使用SANA-Sprint进行文本到图像生成的管道。
__call__
< 源 >( prompt: typing.Union[str, typing.List[str]] = None num_inference_steps: int = 2 timesteps: typing.List[int] = None max_timesteps: float = 1.5708 intermediate_timesteps: float = 1.3 guidance_scale: float = 4.5 num_images_per_prompt: typing.Optional[int] = 1 height: int = 1024 width: int = 1024 eta: float = 0.0 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True clean_caption: bool = False use_resolution_binning: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 300 complex_human_instruction: typing.List[str] = ["Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:", '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.', '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.', 'Here are examples of how to transform or refine prompts:', '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.', '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.', 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:', 'User Prompt: '] ) → SanaPipelineOutput 或 元组
参数
- prompt (
str
或List[str]
, 可选) — 用于引导图像生成的提示词或提示词列表。如果未定义,则必须传入prompt_embeds
。 - num_inference_steps (
int
, 可选, 默认为 20) — 去噪步骤的数量。更多的去噪步骤通常会带来更高质量的图像,但推理速度会变慢。 - max_timesteps (
float
, 可选, 默认为 1.57080) — SCM调度器中使用的最大时间步值。 - intermediate_timesteps (
float
, 可选, 默认为 1.3) — SCM调度器中使用的中间时间步值(仅在num_inference_steps=2时使用)。 - timesteps (
List[int]
, 可选) — 用于去噪过程的自定义时间步,适用于其set_timesteps
方法支持timesteps
参数的调度器。如果未定义,将使用传入num_inference_steps
时的默认行为。必须按降序排列。 - guidance_scale (
float
, 可选, 默认为 4.5) — 在Classifier-Free Diffusion Guidance中定义的引导比例。guidance_scale
被定义为Imagen Paper的公式2中的w
。通过设置guidance_scale > 1
启用引导比例。较高的引导比例鼓励生成与文本prompt
紧密相关的图像,但通常会牺牲图像质量。 - num_images_per_prompt (
int
, 可选, 默认为 1) — 每个提示词生成的图像数量。 - height (
int
, 可选, 默认为 self.unet.config.sample_size) — 生成图像的高度(像素)。 - width (
int
, 可选, 默认为 self.unet.config.sample_size) — 生成图像的宽度(像素)。 - eta (
float
, 可选, 默认为 0.0) — 对应于DDIM论文中的参数eta (η):https://huggingface.co/papers/2010.02502。仅适用于schedulers.DDIMScheduler,对其他调度器将被忽略。 - generator (
torch.Generator
或List[torch.Generator]
, 可选) — 一个或多个torch生成器,用于使生成过程具有确定性。 - latents (
torch.Tensor
, 可选) — 预生成的噪声潜变量,从高斯分布中采样,用作图像生成的输入。可用于使用不同的提示调整相同的生成。如果未提供,将使用提供的随机generator
采样生成一个潜变量张量。 - prompt_embeds (
torch.Tensor
, 可选) — 预生成的文本嵌入。可用于轻松调整文本输入,例如提示词权重。如果未提供,将从prompt
输入参数生成文本嵌入。 - prompt_attention_mask (
torch.Tensor
, 可选) — 预生成的文本嵌入的注意力掩码。 - output_type (
str
, 可选, 默认为"pil"
) — 生成图像的输出格式。在PIL:PIL.Image.Image
或np.array
之间选择。 - return_dict (
bool
, 可选, 默认为True
) — 是否返回~pipelines.stable_diffusion.IFPipelineOutput
而不是普通元组。 - attention_kwargs — 一个kwargs字典,如果指定,将作为
self.processor
中定义的AttentionProcessor
的参数传递,详见diffusers.models.attention_processor。 - clean_caption (
bool
, 可选, 默认为True
) — 是否在创建嵌入之前清理标题。需要安装beautifulsoup4
和ftfy
。如果未安装依赖项,嵌入将从原始提示创建。 - use_resolution_binning (
bool
默认为True
) — 如果设置为True
,请求的高度和宽度将首先使用ASPECT_RATIO_1024_BIN
映射到最接近的分辨率。生成的潜在变量解码为图像后,再将其大小调整回请求的分辨率。对于生成非正方形图像很有用。 - callback_on_step_end (
Callable
, 可选) — 在推理过程中每个去噪步骤结束时调用的函数。该函数将使用以下参数调用:callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)
。callback_kwargs
将包含callback_on_step_end_tensor_inputs
中指定的所有张量列表。 - callback_on_step_end_tensor_inputs (
List
, 可选) —callback_on_step_end
函数的张量输入列表。列表中指定的张量将作为callback_kwargs
参数传递。您只能包含管道类._callback_tensor_inputs
属性中列出的变量。 - max_sequence_length (
int
默认为300
) — 用于prompt
的最大序列长度。 - complex_human_instruction (
List[str]
, 可选) — 复杂人类注意力的指令:https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55。
返回
SanaPipelineOutput 或 tuple
如果 return_dict
为 True
,则返回 SanaPipelineOutput,否则返回一个 tuple
,其中第一个元素是生成的图像列表
调用管道进行生成时调用的函数。
示例
>>> import torch
>>> from diffusers import SanaSprintPipeline
>>> pipe = SanaSprintPipeline.from_pretrained(
... "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> image = pipe(prompt="a tiny astronaut hatching from an egg on the moon")[0]
>>> image[0].save("output.png")
禁用切片 VAE 解码。如果之前启用了 enable_vae_slicing
,此方法将返回一步计算解码。
禁用平铺 VAE 解码。如果之前启用了 enable_vae_tiling
,此方法将恢复一步计算解码。
启用切片 VAE 解码。启用此选项后,VAE 会将输入张量分片,分步计算解码。这有助于节省一些内存并允许更大的批次大小。
启用平铺 VAE 解码。启用此选项后,VAE 将把输入张量分割成瓦片,分多步计算编码和解码。这对于节省大量内存和处理更大的图像非常有用。
encode_prompt
< source >( prompt: typing.Union[str, typing.List[str]] num_images_per_prompt: int = 1 device: typing.Optional[torch.device] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None clean_caption: bool = False max_sequence_length: int = 300 complex_human_instruction: typing.Optional[typing.List[str]] = None lora_scale: typing.Optional[float] = None )
参数
- prompt (
str
或List[str]
, 可选) — 要编码的提示词 - num_images_per_prompt (
int
, 可选, 默认为 1) — 每个提示词应生成的图像数量 - device — (
torch.device
, 可选): 放置结果嵌入的 torch 设备 - prompt_embeds (
torch.Tensor
, 可选) — 预生成的文本嵌入。可用于轻松调整文本输入,例如提示词权重。如果未提供,将从prompt
输入参数生成文本嵌入。 - clean_caption (
bool
, 默认为False
) — 如果为True
,函数将在编码前预处理并清理提供的字幕。 - max_sequence_length (
int
, 默认为 300) — 用于提示词的最大序列长度。 - complex_human_instruction (
list[str]
, 默认为complex_human_instruction
) — 如果complex_human_instruction
不为空,函数将使用复杂的“人类指令”作为提示词。
将提示编码为文本编码器隐藏状态。
SanaSprintImg2ImgPipeline
class diffusers.SanaSprintImg2ImgPipeline
< source >( tokenizer: typing.Union[transformers.models.gemma.tokenization_gemma.GemmaTokenizer, transformers.models.gemma.tokenization_gemma_fast.GemmaTokenizerFast] text_encoder: Gemma2PreTrainedModel vae: AutoencoderDC transformer: SanaTransformer2DModel scheduler: DPMSolverMultistepScheduler )
使用SANA-Sprint进行文本到图像生成的管道。
__call__
< source >( prompt: typing.Union[str, typing.List[str]] = None num_inference_steps: int = 2 timesteps: typing.List[int] = None max_timesteps: float = 1.5708 intermediate_timesteps: float = 1.3 guidance_scale: float = 4.5 image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] = None strength: float = 0.6 num_images_per_prompt: typing.Optional[int] = 1 height: int = 1024 width: int = 1024 eta: float = 0.0 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True clean_caption: bool = False use_resolution_binning: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 300 complex_human_instruction: typing.List[str] = ["Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:", '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.', '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.', 'Here are examples of how to transform or refine prompts:', '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.', '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.', 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:', 'User Prompt: '] ) → SanaPipelineOutput 或 tuple
参数
- prompt (
str
或List[str]
, 可选) — 用于引导图像生成的提示词。如果未定义,则必须传入prompt_embeds
。 - num_inference_steps (
int
, 可选, 默认为 20) — 去噪步数。更多的去噪步数通常会带来更高质量的图像,但推理速度会变慢。 - max_timesteps (
float
, 可选, 默认为 1.57080) — SCM 调度器中使用的最大时间步值。 - intermediate_timesteps (
float
, 可选, 默认为 1.3) — SCM 调度器中使用的中间时间步值(仅当 num_inference_steps=2 时使用)。 - timesteps (
List[int]
, 可选) — 用于去噪过程的自定义时间步,适用于其set_timesteps
方法支持timesteps
参数的调度器。如果未定义,将使用传入num_inference_steps
时的默认行为。必须按降序排列。 - guidance_scale (
float
, 可选, 默认为 4.5) — 如 Classifier-Free Diffusion Guidance 中定义的引导比例。guidance_scale
定义为 Imagen Paper 方程 2 中的w
。通过设置guidance_scale > 1
启用引导比例。更高的引导比例鼓励生成与文本prompt
密切相关的图像,通常以牺牲较低图像质量为代价。 - num_images_per_prompt (
int
, 可选, 默认为 1) — 每个提示词要生成的图像数量。 - height (
int
, 可选, 默认为 self.unet.config.sample_size) — 生成图像的像素高度。 - width (
int
, 可选, 默认为 self.unet.config.sample_size) — 生成图像的像素宽度。 - eta (
float
, 可选, 默认为 0.0) — 对应于 DDIM 论文中的参数 eta (η):https://arxiv.org/abs/2010.02502。仅适用于 schedulers.DDIMScheduler,对其他调度器将被忽略。 - generator (
torch.Generator
或List[torch.Generator]
, 可选) — 一个或多个 torch 生成器,用于使生成具有确定性。 - latents (
torch.Tensor
, 可选) — 预先生成的噪声潜像,从高斯分布中采样,用作图像生成的输入。可用于通过不同提示词调整同一生成。如果未提供,将使用提供的随机generator
采样生成潜像张量。 - prompt_embeds (
torch.Tensor
, 可选) — 预生成的文本嵌入。可用于轻松调整文本输入,例如提示词权重。如果未提供,文本嵌入将从prompt
输入参数生成。 - prompt_attention_mask (
torch.Tensor
, 可选) — 文本嵌入的预生成注意力掩码。 - output_type (
str
, 可选, 默认为"pil"
) — 生成图像的输出格式。在 PIL:PIL.Image.Image
或np.array
之间选择。 - return_dict (
bool
, 可选, 默认为True
) — 是否返回~pipelines.stable_diffusion.IFPipelineOutput
而不是普通元组。 - attention_kwargs — 如果指定,将作为 kwargs 字典传递给 diffusers.models.attention_processor 中
self.processor
下定义的AttentionProcessor
。 - clean_caption (
bool
, 可选, 默认为True
) — 是否在创建嵌入前清理字幕。需要安装beautifulsoup4
和ftfy
。如果未安装依赖项,则将从原始提示词创建嵌入。 - use_resolution_binning (
bool
默认为True
) — 如果设置为True
,请求的高度和宽度将首先使用ASPECT_RATIO_1024_BIN
映射到最接近的分辨率。生成潜像解码为图像后,将它们的大小调整回请求的分辨率。对于生成非方形图像很有用。 - callback_on_step_end (
Callable
, 可选) — 在推理过程中每个去噪步骤结束时调用的函数。该函数将使用以下参数调用:callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)
。callback_kwargs
将包含callback_on_step_end_tensor_inputs
指定的所有张量列表。 - callback_on_step_end_tensor_inputs (
List
, 可选) —callback_on_step_end
函数的张量输入列表。列表中指定的张量将作为callback_kwargs
参数传递。您只能包含管道类._callback_tensor_inputs
属性中列出的变量。 - max_sequence_length (
int
默认为300
) — 与prompt
一起使用的最大序列长度。 - complex_human_instruction (
List[str]
, 可选) — 复杂人类注意力的指令:https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55。
返回
SanaPipelineOutput 或 tuple
如果 return_dict
为 True
,则返回 SanaPipelineOutput,否则返回一个 tuple
,其中第一个元素是生成的图像列表
调用管道进行生成时调用的函数。
示例
>>> import torch
>>> from diffusers import SanaSprintImg2ImgPipeline
>>> from diffusers.utils.loading_utils import load_image
>>> pipe = SanaSprintImg2ImgPipeline.from_pretrained(
... "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> image = load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
... )
>>> image = pipe(prompt="a cute pink bear", image=image, strength=0.5, height=832, width=480).images[0]
>>> image[0].save("output.png")
禁用切片 VAE 解码。如果之前启用了 enable_vae_slicing
,此方法将返回一步计算解码。
禁用平铺 VAE 解码。如果之前启用了 enable_vae_tiling
,此方法将恢复一步计算解码。
启用切片 VAE 解码。启用此选项后,VAE 会将输入张量分片,分步计算解码。这有助于节省一些内存并允许更大的批次大小。
启用平铺 VAE 解码。启用此选项后,VAE 将把输入张量分割成瓦片,分多步计算编码和解码。这对于节省大量内存和处理更大的图像非常有用。
encode_prompt
< source >( prompt: typing.Union[str, typing.List[str]] num_images_per_prompt: int = 1 device: typing.Optional[torch.device] = None prompt_embeds: typing.Optional[torch.Tensor] = None prompt_attention_mask: typing.Optional[torch.Tensor] = None clean_caption: bool = False max_sequence_length: int = 300 complex_human_instruction: typing.Optional[typing.List[str]] = None lora_scale: typing.Optional[float] = None )
参数
- prompt (
str
或List[str]
, 可选) — 要编码的提示词 - num_images_per_prompt (
int
, 可选, 默认为 1) — 每个提示词应生成的图像数量 - device — (
torch.device
, 可选): 用于放置结果嵌入的torch设备。 - prompt_embeds (
torch.Tensor
, 可选) — 预生成的文本嵌入。可用于轻松调整文本输入,例如提示词权重。如果未提供,将从prompt
输入参数生成文本嵌入。 - clean_caption (
bool
, 默认为False
) — 如果为True
,函数将在编码前预处理并清理提供的标题。 - max_sequence_length (
int
, 默认为 300) — 用于提示词的最大序列长度。 - complex_human_instruction (
list[str]
, 默认为complex_human_instruction
) — 如果complex_human_instruction
不为空,函数将使用复杂的人类指令作为提示词。
将提示编码为文本编码器隐藏状态。
SanaPipelineOutput
class diffusers.pipelines.sana.pipeline_output.SanaPipelineOutput
< source >( images: typing.Union[typing.List[PIL.Image.Image], numpy.ndarray] )
Sana 管道的输出类。