🧨 JAX / Flax 中的 Stable Diffusion!

发布于 2022 年 10 月 13 日
在 GitHub 上更新
Open In Colab

🤗 Hugging Face Diffusers0.5.1 版本开始支持 Flax!这使得在 Google TPU 上进行超快速推理成为可能,例如 Colab、Kaggle 或 Google Cloud Platform 中可用的 TPU。

这篇帖子展示了如何使用 JAX / Flax 运行推理。如果您想了解更多关于 Stable Diffusion 如何工作的详细信息,或者想在 GPU 上运行它,请参阅 此 Colab 笔记本

如果您想跟着操作,请点击上面的按钮,将此帖子作为 Colab 笔记本打开。

首先,请确保您正在使用 TPU 后端。如果您在 Colab 中运行此笔记本,请在上方菜单中选择 Runtime,然后选择“更改运行时类型”选项,然后在 Hardware accelerator 设置下选择 TPU

请注意,JAX 并非 TPU 独有,但它在这种硬件上表现出色,因为每个 TPU 服务器都有 8 个 TPU 加速器并行工作。

设置

import jax
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")
assert "TPU" in device_type, "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"

输出:

    Found 8 JAX devices of type TPU v2.

请确保已安装 diffusers

!pip install diffusers==0.5.1

然后我们导入所有依赖项。

import numpy as np
import jax
import jax.numpy as jnp

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline

模型加载

在使用模型之前,您需要接受模型 许可证 才能下载和使用权重。

该许可证旨在减轻这种强大机器学习系统可能造成的有害影响。我们要求用户**完整并仔细阅读许可证**。以下是摘要:

  1. 您不得使用该模型故意生成或共享非法或有害的输出或内容,
  2. 我们不主张您生成的输出的任何权利,您可以自由使用它们,并对其使用负责,其使用不应违反许可证中规定的条款,并且
  3. 您可以重新分发权重并将其商业化和/或作为服务使用。如果您这样做,请注意您必须包含与许可证中相同的限制,并将 CreativeML OpenRAIL-M 的副本分享给所有用户。

Flax 权重作为 Stable Diffusion 仓库的一部分,在 Hugging Face Hub 中可用。Stable Diffusion 模型根据 CreateML OpenRail-M 许可证分发。这是一个开放许可证,不对您生成的输出主张任何权利,并禁止您故意生成非法或有害内容。模型卡提供了更多详细信息,请花点时间阅读并仔细考虑您是否接受该许可证。如果您接受,您需要成为 Hub 中的注册用户并使用访问令牌才能使代码正常工作。您有两种选择来提供您的访问令牌:

  • 在您的终端中使用 huggingface-cli login 命令行工具,并在提示时粘贴您的令牌。它将保存在您计算机上的文件中。
  • 或者在笔记本中使用 notebook_login(),它做的是同样的事情。

除非您之前已在此计算机上进行过身份验证,否则以下单元格将显示登录界面。您需要粘贴您的访问令牌。

if not (Path.home()/'.huggingface'/'token').exists(): notebook_login()

TPU 设备支持 bfloat16,一种高效的半精度浮点类型。我们将在测试中使用它,但您也可以使用 float32 来代替使用全精度。

dtype = jnp.bfloat16

Flax 是一个函数式框架,因此模型是无状态的,参数存储在模型之外。加载预训练的 Flax pipeline 将同时返回 pipeline 本身和模型权重(或参数)。我们正在使用 bf16 版本的权重,这会导致类型警告,您可以安全地忽略它们。

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=dtype,
)

推理

由于 TPU 通常有 8 个设备并行工作,我们将把提示复制多次,以匹配设备的数量。然后我们将同时在 8 个设备上执行推理,每个设备负责生成一张图像。因此,我们将在单个芯片生成一张图像的相同时间内获得 8 张图像。

复制提示后,我们通过调用 pipeline 的 prepare_inputs 函数获得分词后的文本 ID。分词后的文本长度设置为 77 个 token,这是底层 CLIP Text 模型的配置所要求的。

prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape

输出:

    (8, 77)

复制与并行化

模型参数和输入必须在我们的 8 个并行设备之间复制。参数字典使用 flax.jax_utils.replicate 复制,该函数遍历字典并改变权重的形状,使其重复 8 次。数组使用 shard 复制。

p_params = replicate(params)
prompt_ids = shard(prompt_ids)
prompt_ids.shape

输出:

    (8, 1, 77)

这种形状意味着每个 8 个设备将接收一个形状为 (1, 77)jnp 数组作为输入。因此,1 是每个设备的批处理大小。在内存充足的 TPU 中,如果我们想一次生成多张图像(每个芯片),它可能会大于 1

我们几乎准备好生成图像了!我们只需要创建一个随机数生成器来传递给生成函数。这是 Flax 中的标准过程,它对随机数非常认真和有主见——所有处理随机数的函数都应该接收一个生成器。这确保了可重现性,即使我们在多个分布式设备上进行训练。

下面的辅助函数使用一个种子来初始化一个随机数生成器。只要我们使用相同的种子,我们将得到完全相同的结果。稍后在笔记本中探索结果时,随意使用不同的种子。

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

我们获得一个随机数生成器,然后将其“分割”成 8 份,以便每个设备接收一个不同的生成器。因此,每个设备将创建一个不同的图像,并且整个过程是可重现的。

rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

JAX 代码可以编译成高效的表示,运行速度非常快。然而,我们需要确保所有输入在后续调用中都具有相同的形状;否则,JAX 将不得不重新编译代码,我们将无法利用优化的速度。

如果我们将 jit = True 作为参数传递,Flax pipeline 可以为我们编译代码。它还将确保模型在 8 个可用设备上并行运行。

我们第一次运行以下单元格时,编译将花费很长时间,但随后的调用(即使输入不同)也会快得多。例如,我在 TPU v2-8 上测试时,编译花费了超过一分钟,但随后的推理运行仅需约 7秒

images = pipeline(prompt_ids, p_params, rng, jit=True)[0]

输出:

    CPU times: user 464 ms, sys: 105 ms, total: 569 ms
    Wall time: 7.07 s

返回的数组形状为 (8, 1, 512, 512, 3)。我们将其重塑以去除第二个维度,得到 8 张 512 × 512 × 3 的图像,然后将它们转换为 PIL 图像。

images = images.reshape((images.shape[0],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

可视化

让我们创建一个辅助函数来以网格形式显示图像。

def image_grid(imgs, rows, cols):
    w,h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid
image_grid(images, 2, 4)

png

使用不同的提示

我们不必在所有设备上复制相同的提示。我们可以做任何我们想做的事情:生成 2 个提示,每个提示 4 次,甚至一次生成 8 个不同的提示。让我们开始吧!

首先,我们将输入准备代码重构为一个方便的函数

prompts = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Close-up photograph of young black woman against urban background, high quality, bokeh",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background",
]
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0], ) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
image_grid(images, 2, 4)

png


并行化如何工作?

我们之前提到,diffusers Flax pipeline 会自动编译模型并在所有可用设备上并行运行。现在,我们将简要了解该过程的内部工作原理。

JAX 并行化可以通过多种方式完成。最简单的方法是使用 jax.pmap 函数来实现单程序多数据(SPMD)并行化。这意味着我们将在不同的数据输入上运行同一代码的多个副本。更复杂的方法也是可能的,如果您感兴趣,我们邀请您查阅 JAX 文档pjit 页面,以探索此主题!

jax.pmap 为我们做了两件事

  • 编译(或 jit)代码,就像我们调用了 jax.jit() 一样。这在调用 pmap 时不会发生,而是在第一次调用 pmapped 函数时发生。
  • 确保编译后的代码在所有可用设备上并行运行。

为了展示它的工作原理,我们使用 pmap 处理 pipeline 的 _generate 方法,这是运行生成图像的私有方法。请注意,此方法在未来的 diffusers 版本中可能会被重命名或删除。

p_generate = pmap(pipeline._generate)

使用 pmap 后,准备好的函数 p_generate 将在概念上执行以下操作

  • 在每个设备中调用底层函数 pipeline._generate 的副本。
  • 向每个设备发送输入参数的不同部分。这就是分片的目的。在我们的例子中,prompt_ids 的形状是 (8, 1, 77, 768)。此数组将被分成 8 份,每个 _generate 副本将接收一个形状为 (1, 77, 768) 的输入。

我们可以完全忽略它将并行调用的事实来编写 _generate。我们只关心我们的批处理大小(本例中为 1)和对我们的代码有意义的维度,无需更改任何内容即可使其并行工作。

与我们使用 pipeline 调用时一样,第一次运行以下单元格需要一段时间,但之后会快得多。

images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape

输出:

    CPU times: user 118 ms, sys: 83.9 ms, total: 202 ms
    Wall time: 6.82 s

    (8, 1, 512, 512, 3)

我们使用 block_until_ready() 来正确测量推理时间,因为 JAX 使用异步调度并在它能够返回 Python 循环时立即返回控制。您不需要在代码中使用它;当您想要使用尚未实现的计算结果时,阻塞会自动发生。

社区

注册登录 发表评论