JAX/Flax
🤗 Diffusers 支持使用 Flax 在 Google TPU 上进行超快速推理,例如 Colab、Kaggle 或 Google Cloud Platform 上提供的 TPU。本指南将向您展示如何使用 JAX/Flax 运行 Stable Diffusion 推理。
在开始之前,请确保您已安装必要的库
# uncomment to install the necessary libraries in Colab
#!pip install -q jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy
#!pip install -q diffusers
您还应确保您正在使用 TPU 后端。虽然 JAX 不仅限于在 TPU 上运行,但您将在 TPU 上获得最佳性能,因为每个服务器都有 8 个 TPU 加速器并行工作。
如果您在 Colab 中运行本指南,请在上面的菜单中选择“运行时”,选择“更改运行时类型”选项,然后在“硬件加速器”设置下选择“TPU”。导入 JAX 并快速检查您是否正在使用 TPU
import jax
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
print(f"Found {num_devices} JAX devices of type {device_type}.")
assert (
"TPU" in device_type,
"Available device is not a TPU, please select TPU from Runtime > Change runtime type > Hardware accelerator"
)
# Found 8 JAX devices of type Cloud TPU.
很好,现在您可以导入您需要的其他依赖项了
import jax.numpy as jnp
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline
加载模型
Flax 是一个函数式框架,因此模型是无状态的,参数存储在模型外部。加载预训练的 Flax 管道会返回管道和模型权重(或参数)。在本指南中,您将使用bfloat16
,这是一种更有效的半精度浮点数类型,受 TPU 支持(如果需要,您也可以使用float32
以获得完全精度)。
dtype = jnp.bfloat16
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
variant="bf16",
dtype=dtype,
)
推理
TPU 通常有 8 个设备并行工作,因此让我们对每个设备使用相同的提示。这意味着您可以同时在 8 个设备上执行推理,每个设备生成一个图像。因此,您将在一个芯片生成单个图像所需的时间内获得 8 张图像!
在并行化是如何工作的?部分了解更多详细信息。
复制提示后,通过在管道上调用prepare_inputs
函数获取标记化的文本 ID。标记化文本的长度设置为 77 个标记,这是底层 CLIP 文本模型配置的要求。
prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
# (8, 77)
模型参数和输入必须在 8 个并行设备之间复制。参数字典使用flax.jax_utils.replicate
进行复制,该函数遍历字典并更改权重的形状,使其重复 8 次。数组使用shard
进行复制。
# parameters
p_params = replicate(params)
# arrays
prompt_ids = shard(prompt_ids)
prompt_ids.shape
# (8, 1, 77)
此形状意味着 8 个设备中的每一个都接收一个形状为(1, 77)
的jnp
数组作为输入,其中1
是每个设备的批次大小。在内存充足的 TPU 上,如果要同时生成多个图像(每个芯片),则批次大小可以大于1
。
接下来,创建一个随机数生成器以传递给生成函数。这是 Flax 中的标准程序,它对随机数非常重视和有主见。所有处理随机数的函数都期望接收一个生成器以确保可重复性,即使您是在多个分布式设备上进行训练。
下面的辅助函数使用种子初始化随机数生成器。只要您使用相同的种子,您就会得到完全相同的结果。在稍后指南中探索结果时,可以随意使用不同的种子。
def create_key(seed=0):
return jax.random.PRNGKey(seed)
辅助函数或rng
被拆分为 8 次,以便每个设备接收一个不同的生成器并生成不同的图像。
rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())
为了利用 JAX 在 TPU 上的优化速度,将jit=True
传递给管道以将 JAX 代码编译成有效的表示形式,并确保模型在 8 个设备上并行运行。
您需要确保后续调用中的所有输入都具有相同的形状,否则 JAX 将需要重新编译代码,这会更慢。
第一次推理运行需要更多时间,因为它需要编译代码,但后续调用(即使输入不同)的速度要快得多。例如,在 TPU v2-8 上编译需要一分钟以上的时间,但在以后的推理运行中大约需要7 秒!
%%time
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
# CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
# Wall time: 1min 29s
返回的数组形状为(8, 1, 512, 512, 3)
,应将其重新整形以移除第二个维度并获得 8 张512 × 512 × 3
的图像。然后,您可以使用numpy_to_pil()函数将数组转换为图像。
from diffusers.utils import make_image_grid
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
make_image_grid(images, rows=2, cols=4)
使用不同的提示
您不一定要在所有设备上使用相同的提示。例如,要生成 8 个不同的提示
prompts = [
"Labrador in the style of Hokusai",
"Painting of a squirrel skating in New York",
"HAL-9000 in the style of Van Gogh",
"Times Square under water, with fish and a dolphin swimming around",
"Ancient Roman fresco showing a man working on his laptop",
"Close-up photograph of young black woman against urban background, high quality, bokeh",
"Armchair in the shape of an avocado",
"Clown astronaut in space, with Earth in the background",
]
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
make_image_grid(images, 2, 4)
并行化是如何工作的?
🤗 Diffusers 中的 Flax 管道会自动编译模型并在所有可用设备上并行运行它。让我们更仔细地看看该过程是如何工作的。
JAX 并行化可以通过多种方式完成。最简单的方法是使用jax.pmap
函数实现单程序多数据 (SPMD) 并行化。这意味着在不同的数据输入上运行同一代码的多个副本。还可以使用更复杂的方法,如果您有兴趣,可以访问 JAX 的文档,更详细地了解此主题!
jax.pmap
执行两件事
- 编译(或“
jit
”)代码,这类似于jax.jit()
。这不会在您调用pmap
时发生,而只会在第一次调用pmap
ped 函数时发生。 - 确保编译后的代码在所有可用设备上并行运行。
为了演示,在管道的_generate
方法上调用pmap
(这是一个生成图像的私有方法,在 🤗 Diffusers 的未来版本中可能会重命名或删除)
p_generate = pmap(pipeline._generate)
在每个设备上复制底层函数pipeline._generate
。
- 在调用
pmap
后,准备好的函数p_generate
将 - 将输入参数的不同部分发送到每个设备(这就是需要调用shard函数的原因)。在本例中,
prompt_ids
的形状为(8, 1, 77, 768)
,因此数组被分成 8 部分,每个_generate
的副本接收形状为(1, 77, 768)
的输入。
这里最需要注意的是批次大小(在本例中为 1)以及对您的代码有意义的输入维度。您无需更改任何其他内容即可使代码并行运行。
第一次调用管道需要更多时间,但后续调用速度快得多。block_until_ready
函数用于正确测量推理时间,因为 JAX 使用异步调度并在其能够尽快将控制权返回给 Python 循环。您无需在代码中使用它;当您想要使用尚未实现的计算结果时,阻塞会自动发生。
%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
# CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
# Wall time: 1min 15s
检查您的图像尺寸是否正确。
images.shape
# (8, 1, 512, 512, 3)
资源
若要详细了解 JAX 如何与 Stable Diffusion 协同工作,您可能希望阅读以下内容:
< > GitHub 更新