在 Cloud TPU v5e 上使用 JAX 加速 Stable Diffusion XL 推理

发布于 2023 年 10 月 3 日
在 GitHub 上更新

像 Stable Diffusion XL (SDXL) 这样的生成式 AI 模型能够创建高质量、逼真的内容,并具有广泛的应用。然而,要发挥这些模型的力量也带来了巨大的挑战和计算成本。SDXL 是一个大型的图像生成模型,其 UNet 组件的大小约为该模型先前版本的三倍。由于内存需求和推理时间的增加,在生产环境中部署这样的模型具有挑战性。今天,我们激动地宣布,Hugging Face Diffusers 现在支持在 Cloud TPU 上使用 JAX 来服务 SDXL,从而实现高性能、高性价比的推理。

Google Cloud TPU 是定制化设计的 AI 加速器,专为大型 AI 模型(包括最先进的 LLM 和生成式 AI 模型如 SDXL)的训练和推理进行了优化。新的 Cloud TPU v5e 专为大规模 AI 训练推理提供所需的成本效益和性能。TPU v5e 的成本不到 TPU v4 的一半,使得更多的组织能够训练和部署 AI 模型。

🧨 Diffusers JAX 集成提供了一种通过 XLA 在 TPU 上运行 SDXL 的便捷方式,我们构建了一个演示来展示它。您可以在此 Space 或下面嵌入的 playground 中进行尝试

在底层,该演示运行在多个 TPU v5e-4 实例上(每个实例有 4 个 TPU 芯片),并利用并行化技术在大约 4 秒内提供四张 1024×1024 的大图。这个时间包括格式转换、通信时间和前端处理;实际的生成时间约为 2.3 秒,我们将在下面看到!

在这篇博文中,

  1. 我们描述了为什么 JAX + TPU + Diffusers 是运行 SDXL 的强大框架
  2. 解释如何使用 Diffusers 和 JAX 编写一个简单的图像生成流水线
  3. 展示比较不同 TPU 设置的基准测试

为什么选择 JAX + TPU v5e 来运行 SDXL?

通过专用 TPU 硬件和为性能优化的软件栈相结合,可以在 Cloud TPU v5e 上使用 JAX 以高性能和高成本效益地服务 SDXL。下面我们强调两个关键因素:JAX 的即时(jit)编译和使用 JAX pmap 实现的 XLA 编译器驱动的并行化。

即时编译

JAX 的一个显著特点是其即时(jit)编译。JIT 编译器在第一次运行时跟踪代码,并生成高度优化的 TPU 二进制文件,这些文件在后续调用中被重用。这个过程的要点在于,它要求所有输入、中间和输出的形状都是**静态**的,这意味着它们必须是预先知道的。每当我们改变形状,就会再次触发一个新的、代价高昂的编译过程。JIT 编译非常适合那些可以围绕静态形状设计的服务:编译只运行一次,然后我们就可以享受超快的推理速度。

图像生成非常适合 JIT 编译。如果我们总是生成相同数量且大小相同的图像,那么输出形状就是恒定且预先知道的。文本输入也是恒定的:按照设计,Stable Diffusion 和 SDXL 使用固定形状的嵌入向量(带有填充)来表示用户输入的提示。因此,我们可以编写依赖于固定形状的 JAX 代码,从而可以被极大地优化!

针对高批次大小的高性能吞吐量

使用 JAX 的 pmap,可以将工作负载扩展到多个设备上,它表达了单程序多数据(SPMD)程序。将 pmap 应用于一个函数会使用 XLA 编译该函数,然后在各种 XLA 设备上并行执行它。对于文本到图像的生成工作负载,这意味着同时增加渲染的图像数量很容易实现,并且不会影响性能。例如,在有 8 个芯片的 TPU 上运行 SDXL 将在与 1 个芯片创建单个图像相同的时间内生成 8 张图像。

TPU v5e 实例有多种规格,包括 1、4 和 8 芯片的配置,一直到 256 个芯片(一个完整的 TPU v5e pod),芯片之间有超快的 ICI 链接。这允许您选择最适合您用例的 TPU 规格,并轻松利用 JAX 和 TPU 提供的并行性。

如何用 JAX 编写一个图像生成流水线

我们将一步步介绍使用 JAX 实现超快推理所需的代码!首先,让我们导入依赖项。

# Show best practices for SDXL JAX
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from diffusers import FlaxStableDiffusionXLPipeline
import time

现在,我们将加载基础 SDXL 模型以及推理所需的其他组件。diffusers 流水线会为我们处理下载和缓存所有内容。遵循 JAX 的函数式方法,模型的参数会单独返回,并且在推理时必须传递给流水线。

pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", split_head_dim=True
)

默认情况下,模型参数以 32 位精度下载。为了节省内存并加快计算速度,我们会将它们转换为 bfloat16,这是一种高效的 16 位表示。然而,这里有一个注意事项:为了获得最佳效果,我们必须将_调度器状态_保持在 float32,否则精度误差会累积,导致低质量甚至全黑的图像。

scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
params["scheduler"] = scheduler_state

我们现在准备好设置我们的提示和流水线的其余输入了。

default_prompt = "high-quality photo of a baby dolphin ​​playing in a pool and wearing a party hat"
default_neg_prompt = "illustration, low-quality"
default_seed = 33
default_guidance_scale = 5.0
default_num_steps = 25

提示必须作为张量提供给流水线,并且它们在每次调用中必须具有相同的维度。这使得推理调用可以被编译。流水线的 prepare_inputs 方法为我们执行了所有必要的步骤,所以我们将创建一个辅助函数来准备我们的提示和负面提示作为张量。我们稍后会在 generate 函数中使用它。

def tokenize_prompt(prompt, neg_prompt):
    prompt_ids = pipeline.prepare_inputs(prompt)
    neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
    return prompt_ids, neg_prompt_ids

为了利用并行化,我们将在设备之间复制输入。一个 Cloud TPU v5e-4 有 4 个芯片,所以通过复制输入,我们可以让每个芯片并行生成一个不同的图像。我们需要小心为每个芯片提供一个不同的随机种子,这样 4 张图像才会不同。

NUM_DEVICES = jax.device_count()

# Model parameters don't change during inference,
# so we only need to replicate them once.
p_params = replicate(params)

def replicate_all(prompt_ids, neg_prompt_ids, seed):
    p_prompt_ids = replicate(prompt_ids)
    p_neg_prompt_ids = replicate(neg_prompt_ids)
    rng = jax.random.PRNGKey(seed)
    rng = jax.random.split(rng, NUM_DEVICES)
    return p_prompt_ids, p_neg_prompt_ids, rng

我们现在准备将所有东西整合到一个生成函数中。

def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=neg_prompt_ids,
        guidance_scale=guidance_scale,
        jit=True,
    ).images

    # convert the images to PIL
    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))

jit=True 表示我们希望编译流水线调用。这将在我们第一次调用 generate 时发生,并且会非常慢——JAX 需要跟踪操作,优化它们,并将其转换为低级原语。我们将运行第一次生成来完成这个过程并进行预热。

start = time.time()
print(f"Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start}")

我们第一次运行时,这大约花了三分钟。但是一旦代码被编译,推理就会变得超快。让我们再试一次!

start = time.time()
prompt = "llama in ancient Greece, oil on canvas"
neg_prompt = "cartoon, illustration, animation"
images = generate(prompt, neg_prompt)
print(f"Inference in {time.time() - start}")

现在生成这 4 张图片只花了大约 2 秒!

基准测试

以下测量结果是在运行 SDXL 1.0 base 模型 20 个步骤,并使用默认的 Euler Discrete 调度器获得的。我们比较了相同批次大小下 Cloud TPU v5e 与 TPUv4 的性能。请注意,由于并行性,像我们在演示中使用的 TPU v5e-4,在使用批次大小为 1 时将生成 **4 张图像**(或在使用批次大小为 2 时生成 8 张图像)。同样,TPU v5e-8 在使用批次大小为 1 时将生成 8 张图像。

Cloud TPU 测试使用 Python 3.10 和 jax 0.4.16 版本进行。这些规格与我们的演示 Space 中使用的相同。

批量大小 延迟 性价比 (Perf/$)
TPU v5e-4 (JAX) 4 2.33 秒 21.46
8 4.99 秒 20.04
TPU v4-8 (JAX) 4 2.16 秒 9.05
8 4.17 8.98

TPU v5e 在 SDXL 上的性价比高达 TPU v4 的 2.4 倍,展示了最新一代 TPU 的成本效益。

为了衡量推理性能,我们使用行业标准的吞吐量指标。首先,我们测量模型编译和加载后每张图像的延迟。然后,我们通过将批次大小除以每个芯片的延迟来计算吞吐量。因此,吞吐量衡量的是模型在生产环境中的性能,无论使用多少芯片。然后,我们将吞吐量除以标价,得到单位成本的性能。

这个演示是如何工作的?

我们之前展示的演示是使用一个脚本构建的,该脚本基本上遵循了我们在这篇博文中发布的代​​码。它运行在几个各带 4 个芯片的 Cloud TPU v5e 设备上,还有一个简单的负载均衡服务器,随机将用户请求路由到后端服务器。当您在演示中输入提示时,您的请求将被分配到其中一个后端服务器,然后您将收到它生成的 4 张图像。

这是一个基于多个预分配 TPU 实例的简单解决方案。在未来的文章中,我们将介绍如何使用 GKE 创建适应负载的动态解决方案。

演示的所有代码都是开源的,并且现在可以在 Hugging Face Diffusers 中找到。我们很期待看到您使用 Diffusers + JAX + Cloud TPU 构建的作品!

社区

注册登录 发表评论