Diffusers 文档

JAX/Flax

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

JAX/Flax

🤗 Diffusers 支持 Flax,可在 Google TPU(例如 Colab、Kaggle 或 Google Cloud Platform 中提供的 TPU)上实现超快速推理。本指南将展示如何使用 JAX/Flax 运行 Stable Diffusion 推理。

在开始之前,请确保已安装必要的库

# uncomment to install the necessary libraries in Colab
#!pip install -q jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy
#!pip install -q diffusers

您还应确保您正在使用 TPU 后端。虽然 JAX 并非专门在 TPU 上运行,但在 TPU 上您将获得最佳性能,因为每个服务器都有 8 个 TPU 加速器并行工作。

如果您在 Colab 中运行本指南,请在上面的菜单中选择运行时,选择更改运行时类型选项,然后在硬件加速器设置下选择 TPU。导入 JAX 并快速检查您是否正在使用 TPU

import jax
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")
assert (
    "TPU" in device_type,
    "Available device is not a TPU, please select TPU from Runtime > Change runtime type > Hardware accelerator"
)
# Found 8 JAX devices of type Cloud TPU.

太棒了,现在您可以导入所需的其余依赖项了

import jax.numpy as jnp
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionPipeline

加载模型

Flax 是一个函数式框架,因此模型是无状态的,参数存储在模型外部。加载预训练的 Flax pipeline 会同时返回 pipeline 和模型权重(或参数)。在本指南中,您将使用 bfloat16,这是一种更高效的半精度浮点类型,TPU 支持该类型(如果您需要完全精度,也可以使用 float32)。

dtype = jnp.bfloat16
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    variant="bf16",
    dtype=dtype,
)

推理

TPU 通常有 8 个设备并行工作,因此让我们为每个设备使用相同的 prompt。这意味着您可以一次在 8 个设备上执行推理,每个设备生成一个图像。因此,您将在与单个芯片生成单个图像相同的时间内获得 8 个图像!

并行化如何工作?部分了解更多详情。

复制 prompt 后,通过在 pipeline 上调用 prepare_inputs 函数来获取分词文本 ID。分词文本的长度设置为 77 个 tokens,这是底层 CLIP 文本模型配置的要求。

prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
# (8, 77)

模型参数和输入必须在 8 个并行设备之间复制。参数字典使用 flax.jax_utils.replicate 复制,该函数遍历字典并更改权重的形状,使其重复 8 次。数组使用 shard 复制。

# parameters
p_params = replicate(params)

# arrays
prompt_ids = shard(prompt_ids)
prompt_ids.shape
# (8, 1, 77)

此形状意味着 8 个设备中的每一个都收到一个形状为 (1, 77)jnp 数组作为输入,其中 1 是每个设备的批量大小。在具有足够内存的 TPU 上,如果您想一次生成多个图像(每个芯片),则可以使用大于 1 的批量大小。

接下来,创建一个随机数生成器,传递给生成函数。这是 Flax 中的标准过程,Flax 对随机数非常认真和固执己见。所有处理随机数的函数都应接收生成器,以确保可重复性,即使在跨多个分布式设备进行训练时也是如此。

下面的辅助函数使用种子来初始化随机数生成器。只要您使用相同的种子,您将获得完全相同的结果。在稍后指南中探索结果时,可以随意使用不同的种子。

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

辅助函数或 rng 被拆分 8 次,因此每个设备都收到不同的生成器并生成不同的图像。

rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

为了利用 JAX 在 TPU 上的优化速度,请将 jit=True 传递给 pipeline,以将 JAX 代码编译为高效表示形式,并确保模型在 8 个设备上并行运行。

您需要确保所有后续调用中的输入都具有相同的形状,否则 JAX 将需要重新编译代码,这会比较慢。

第一次推理运行需要更多时间,因为它需要编译代码,但后续调用(即使使用不同的输入)会快得多。例如,在 TPU v2-8 上编译需要一分多钟,但随后的推理运行大约需要 7 秒

%%time
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]

# CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
# Wall time: 1min 29s

返回的数组形状为 (8, 1, 512, 512, 3),应将其重塑以删除第二个维度,并获得 8 个 512 × 512 × 3 的图像。然后,您可以使用 numpy_to_pil() 函数将数组转换为图像。

from diffusers.utils import make_image_grid

images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
make_image_grid(images, rows=2, cols=4)

img

使用不同的 prompts

您不一定必须在所有设备上使用相同的 prompt。例如,要生成 8 个不同的 prompts

prompts = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Close-up photograph of young black woman against urban background, high quality, bokeh",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background",
]

prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

make_image_grid(images, 2, 4)

img

并行化如何工作?

🤗 Diffusers 中的 Flax pipeline 会自动编译模型并在所有可用设备上并行运行它。让我们仔细看看这个过程是如何工作的。

JAX 并行化可以通过多种方式完成。最简单的一种是围绕使用 jax.pmap 函数来实现单程序多数据 (SPMD) 并行化。这意味着运行同一代码的多个副本,每个副本使用不同的数据输入。更复杂的方法也是可能的,如果您有兴趣,可以查阅 JAX 文档以更详细地探索此主题!

jax.pmap 做两件事

  1. 编译(或“jit”)代码,这类似于 jax.jit()。这不会在您调用 pmap 时发生,仅在第一次调用 pmap 映射的函数时发生。
  2. 确保编译后的代码在所有可用设备上并行运行。

为了演示,在 pipeline 的 _generate 方法上调用 pmap(这是一个私有方法,用于生成图像,可能会在 🤗 Diffusers 的未来版本中重命名或删除)

p_generate = pmap(pipeline._generate)

调用 pmap 后,准备好的函数 p_generate

  1. 在每个设备上复制底层函数 pipeline._generate
  2. 向每个设备发送输入参数的不同部分(这就是为什么需要调用 shard 函数)。在这种情况下,prompt_ids 的形状为 (8, 1, 77, 768),因此数组被拆分为 8 个部分,并且 _generate 的每个副本都收到形状为 (1, 77, 768) 的输入。

此处最需要注意的是批量大小(在本例中为 1)以及对您的代码有意义的输入维度。您无需更改任何其他内容即可使代码并行工作。

第一次调用 pipeline 需要更多时间,但之后的调用速度要快得多。block_until_ready 函数用于正确测量推理时间,因为 JAX 使用异步调度,并在可以时立即将控制权返回给 Python 循环。您无需在代码中使用它;当您想要使用尚未具体化的计算结果时,会自动发生阻塞。

%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()

# CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
# Wall time: 1min 15s

检查您的图像尺寸以查看它们是否正确

images.shape
# (8, 1, 512, 512, 3)

资源

要了解有关 JAX 如何与 Stable Diffusion 协同工作的更多信息,您可能对阅读以下内容感兴趣

< > 在 GitHub 上更新