介绍 Würstchen:用于图像生成的快速扩散模型

发布于 2023 年 9 月 13 日
在 GitHub 上更新

Collage of images created with Würstchen

什么是 Würstchen?

Würstchen 是一种扩散模型,其文本条件组件在高度压缩的图像潜在空间中工作。为什么这很重要?压缩数据可以将训练和推理的计算成本降低几个数量级。在 1024×1024 图像上训练比在 32×32 图像上训练昂贵得多。通常,其他工作使用相对较小的压缩,空间压缩范围为 4 倍 - 8 倍。Würstchen 将其推向极致。通过其新颖的设计,它实现了 42 倍的空间压缩!这以前从未见过,因为常见方法在 16 倍空间压缩后无法忠实地重建详细图像。Würstchen 采用两阶段压缩,我们称之为阶段 A 和阶段 B。阶段 A 是 VQGAN,阶段 B 是扩散自编码器(更多详细信息可以在论文中找到)。阶段 A 和 B 合称为*解码器*,因为它们将压缩图像解码回像素空间。第三个模型,阶段 C,在该高度压缩的潜在空间中学习。这种训练所需的计算量只是当前顶级模型所需计算量的一小部分,同时还允许更便宜、更快的推理。我们将阶段 C 称为*先验*。

Würstchen images with Prompts

为什么需要另一个文本到图像模型?

嗯,这个模型非常快且高效。Würstchen 的最大优势在于它可以比 Stable Diffusion XL 等模型更快地生成图像,同时使用更少的内存!因此,对于我们这些没有 A100 的人来说,这将非常有用。以下是与 SDXL 在不同批量大小下的比较:

Inference Speed Plots

此外,Würstchen 的另一个重要优势是降低了训练成本。在 512x512 分辨率下工作的 Würstchen v1 仅需要 9,000 GPU 小时的训练。与 Stable Diffusion 1.4 所花费的 150,000 GPU 小时相比,这表明成本降低了 16 倍,这不仅有利于研究人员进行新实验,还为更多组织训练此类模型打开了大门。Würstchen v2 使用了 24,602 GPU 小时。在分辨率达到 1536 的情况下,这仍然比仅在 512x512 分辨率下训练的 SD1.4 便宜 6 倍。

Inference Speed Plots

你也可以在这里找到详细的解释视频

如何使用 Würstchen?

你可以在这里尝试使用演示:

此外,该模型通过 Diffusers 库提供,因此你可以使用你已经熟悉的界面。例如,以下是如何使用 `AutoPipeline` 运行推理:

import torch
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS

pipeline = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen", torch_dtype=torch.float16).to("cuda")

caption = "Anthropomorphic cat dressed as a firefighter"
images = pipeline(
    caption,
    height=1024,
    width=1536,
    prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
    prior_guidance_scale=4.0,
    num_images_per_prompt=4,
).images

Anthropomorphic cat dressed as a fire-fighter

Würstchen 适用于哪些图像尺寸?

Würstchen 在 1024x1024 到 1536x1536 之间的图像分辨率上进行训练。我们有时也会在 1024x2048 等分辨率下观察到不错的输出。欢迎随意尝试。我们还观察到先验(Stage C)能极快地适应新分辨率。因此,在 2048x2048 分辨率下进行微调的计算成本应该很低。

Hub 上的模型

所有检查点都可以在Huggingface Hub上找到。那里可以找到多个检查点,以及未来的演示和模型权重。目前,先验有 3 个检查点可用,解码器有 1 个检查点。请参阅文档,其中解释了检查点以及不同的先验模型的用途。

Diffusers 集成

因为 Würstchen 完全集成在 `diffusers` 中,所以它自动附带了各种开箱即用的便利功能和优化。其中包括:

  • 自动使用PyTorch 2 `SDPA`加速注意力,如下所述。
  • 如果需要使用 PyTorch 1.x 而非 2,则支持xFormers flash attention实现。
  • 模型卸载,以便在不使用时将未使用的组件移至 CPU。这可以节省内存,同时性能影响可忽略不计。
  • 顺序 CPU 卸载,适用于内存非常宝贵的情况。内存使用将最小化,但推理速度会变慢。
  • 使用 Compel 库进行提示权重
  • 支持 Apple Silicon Mac 上的`mps` 设备
  • 使用生成器实现可复现性
  • 针对推理的合理默认值,可在大多数情况下生成高质量结果。当然,您可以根据需要调整所有参数!

优化技术 1:Flash Attention

从 2.0 版本开始,PyTorch 集成了高度优化且资源友好的注意力机制版本,称为`torch.nn.functional.scaled_dot_product_attention` 或 SDPA。根据输入的性质,此函数利用多种底层优化。其性能和内存效率超越了传统的注意力模型。值得注意的是,SDPA 函数反映了 Dao 及其团队撰写的《Fast and Memory-Efficient Exact Attention with IO-Awareness》研究论文中强调的 *flash attention* 技术的特点。

如果您使用的是 PyTorch 2.0 或更高版本的 Diffusers,并且 SDPA 函数可访问,则会自动应用这些增强功能。请根据官方指南设置 torch 2.0 或更高版本以开始使用!

images = pipeline(caption, height=1024, width=1536, prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, prior_guidance_scale=4.0, num_images_per_prompt=4).images

有关 `diffusers` 如何利用 SDPA 的深入了解,请查阅文档

如果您使用的是 Pytorch 2.0 之前的版本,仍然可以使用 xFormers 库来实现内存高效的注意力机制。

pipeline.enable_xformers_memory_efficient_attention()

优化技术 2:Torch Compile

如果你正在寻求额外的性能提升,可以使用 `torch.compile`。最好将其应用于先验模型和解码器的主模型,以实现最大的性能提升。

pipeline.prior_prior = torch.compile(pipeline.prior_prior , mode="reduce-overhead", fullgraph=True)
pipeline.decoder = torch.compile(pipeline.decoder, mode="reduce-overhead", fullgraph=True)

请记住,首次推理步骤将花费很长时间(长达 2 分钟),因为模型正在编译。之后,您可以正常运行推理

images = pipeline(caption, height=1024, width=1536, prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, prior_guidance_scale=4.0, num_images_per_prompt=4).images

好消息是这种编译是一次性执行的。之后,您将始终体验到相同图像分辨率下的更快推理。编译的初始时间投入很快就会被随后的速度优势抵消。有关 `torch.compile` 及其细微差别的深入探讨,请查看官方文档

模型是如何训练的?

能够训练这个模型,完全得益于 Stability AI 提供的计算资源。我们要特别感谢 Stability 让我们能够进行这类研究,并有机会让更多人接触到它!

资源

  • 有关此模型的更多信息,请参阅官方 Diffusers 文档
  • 所有检查点都可以在hub上找到
  • 你可以在这里试用演示。
  • 如果你想讨论未来的项目或贡献自己的想法,请加入我们的 Discord
  • 训练代码等更多内容可以在官方 GitHub 存储库中找到。

社区

注册登录 以发表评论