介绍 Würstchen:用于图像生成的快速扩散模型

什么是 Würstchen?
Würstchen 是一种扩散模型,其文本条件组件在高度压缩的图像潜在空间中工作。为什么这很重要?压缩数据可以将训练和推理的计算成本降低几个数量级。在 1024×1024 图像上训练比在 32×32 图像上训练昂贵得多。通常,其他工作使用相对较小的压缩,空间压缩范围为 4 倍 - 8 倍。Würstchen 将其推向极致。通过其新颖的设计,它实现了 42 倍的空间压缩!这以前从未见过,因为常见方法在 16 倍空间压缩后无法忠实地重建详细图像。Würstchen 采用两阶段压缩,我们称之为阶段 A 和阶段 B。阶段 A 是 VQGAN,阶段 B 是扩散自编码器(更多详细信息可以在论文中找到)。阶段 A 和 B 合称为*解码器*,因为它们将压缩图像解码回像素空间。第三个模型,阶段 C,在该高度压缩的潜在空间中学习。这种训练所需的计算量只是当前顶级模型所需计算量的一小部分,同时还允许更便宜、更快的推理。我们将阶段 C 称为*先验*。
为什么需要另一个文本到图像模型?
嗯,这个模型非常快且高效。Würstchen 的最大优势在于它可以比 Stable Diffusion XL 等模型更快地生成图像,同时使用更少的内存!因此,对于我们这些没有 A100 的人来说,这将非常有用。以下是与 SDXL 在不同批量大小下的比较:
此外,Würstchen 的另一个重要优势是降低了训练成本。在 512x512 分辨率下工作的 Würstchen v1 仅需要 9,000 GPU 小时的训练。与 Stable Diffusion 1.4 所花费的 150,000 GPU 小时相比,这表明成本降低了 16 倍,这不仅有利于研究人员进行新实验,还为更多组织训练此类模型打开了大门。Würstchen v2 使用了 24,602 GPU 小时。在分辨率达到 1536 的情况下,这仍然比仅在 512x512 分辨率下训练的 SD1.4 便宜 6 倍。
你也可以在这里找到详细的解释视频
如何使用 Würstchen?
你可以在这里尝试使用演示:
此外,该模型通过 Diffusers 库提供,因此你可以使用你已经熟悉的界面。例如,以下是如何使用 `AutoPipeline` 运行推理:
import torch
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
pipeline = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen", torch_dtype=torch.float16).to("cuda")
caption = "Anthropomorphic cat dressed as a firefighter"
images = pipeline(
caption,
height=1024,
width=1536,
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
prior_guidance_scale=4.0,
num_images_per_prompt=4,
).images
Würstchen 适用于哪些图像尺寸?
Würstchen 在 1024x1024 到 1536x1536 之间的图像分辨率上进行训练。我们有时也会在 1024x2048 等分辨率下观察到不错的输出。欢迎随意尝试。我们还观察到先验(Stage C)能极快地适应新分辨率。因此,在 2048x2048 分辨率下进行微调的计算成本应该很低。
Hub 上的模型
所有检查点都可以在Huggingface Hub上找到。那里可以找到多个检查点,以及未来的演示和模型权重。目前,先验有 3 个检查点可用,解码器有 1 个检查点。请参阅文档,其中解释了检查点以及不同的先验模型的用途。
Diffusers 集成
因为 Würstchen 完全集成在 `diffusers` 中,所以它自动附带了各种开箱即用的便利功能和优化。其中包括:
- 自动使用PyTorch 2 `SDPA`加速注意力,如下所述。
- 如果需要使用 PyTorch 1.x 而非 2,则支持xFormers flash attention实现。
- 模型卸载,以便在不使用时将未使用的组件移至 CPU。这可以节省内存,同时性能影响可忽略不计。
- 顺序 CPU 卸载,适用于内存非常宝贵的情况。内存使用将最小化,但推理速度会变慢。
- 使用 Compel 库进行提示权重。
- 支持 Apple Silicon Mac 上的`mps` 设备。
- 使用生成器实现可复现性。
- 针对推理的合理默认值,可在大多数情况下生成高质量结果。当然,您可以根据需要调整所有参数!
优化技术 1:Flash Attention
从 2.0 版本开始,PyTorch 集成了高度优化且资源友好的注意力机制版本,称为`torch.nn.functional.scaled_dot_product_attention` 或 SDPA。根据输入的性质,此函数利用多种底层优化。其性能和内存效率超越了传统的注意力模型。值得注意的是,SDPA 函数反映了 Dao 及其团队撰写的《Fast and Memory-Efficient Exact Attention with IO-Awareness》研究论文中强调的 *flash attention* 技术的特点。
如果您使用的是 PyTorch 2.0 或更高版本的 Diffusers,并且 SDPA 函数可访问,则会自动应用这些增强功能。请根据官方指南设置 torch 2.0 或更高版本以开始使用!
images = pipeline(caption, height=1024, width=1536, prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, prior_guidance_scale=4.0, num_images_per_prompt=4).images
有关 `diffusers` 如何利用 SDPA 的深入了解,请查阅文档。
如果您使用的是 Pytorch 2.0 之前的版本,仍然可以使用 xFormers 库来实现内存高效的注意力机制。
pipeline.enable_xformers_memory_efficient_attention()
优化技术 2:Torch Compile
如果你正在寻求额外的性能提升,可以使用 `torch.compile`。最好将其应用于先验模型和解码器的主模型,以实现最大的性能提升。
pipeline.prior_prior = torch.compile(pipeline.prior_prior , mode="reduce-overhead", fullgraph=True)
pipeline.decoder = torch.compile(pipeline.decoder, mode="reduce-overhead", fullgraph=True)
请记住,首次推理步骤将花费很长时间(长达 2 分钟),因为模型正在编译。之后,您可以正常运行推理
images = pipeline(caption, height=1024, width=1536, prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, prior_guidance_scale=4.0, num_images_per_prompt=4).images
好消息是这种编译是一次性执行的。之后,您将始终体验到相同图像分辨率下的更快推理。编译的初始时间投入很快就会被随后的速度优势抵消。有关 `torch.compile` 及其细微差别的深入探讨,请查看官方文档。
模型是如何训练的?
能够训练这个模型,完全得益于 Stability AI 提供的计算资源。我们要特别感谢 Stability 让我们能够进行这类研究,并有机会让更多人接触到它!
资源
- 有关此模型的更多信息,请参阅官方 Diffusers 文档。
- 所有检查点都可以在hub上找到
- 你可以在这里试用演示。
- 如果你想讨论未来的项目或贡献自己的想法,请加入我们的 Discord!
- 训练代码等更多内容可以在官方 GitHub 存储库中找到。