扩散模型课程文档

DDIM 反演

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

Open In Colab

DDIM 反演

在这个 notebook 中,我们将探讨反演,了解它与采样的关系,并将其应用于使用 Stable Diffusion 编辑图像的任务。

你将学到什么

  • DDIM 采样的工作原理
  • 确定性采样器 vs 随机性采样器
  • DDIM 反演背后的理论
  • 使用反演编辑图像

让我们开始吧!

环境设置

%pip install -q transformers diffusers accelerate
import torch
import requests
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from io import BytesIO
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
from torchvision import transforms as tfms
from diffusers import StableDiffusionPipeline, DDIMScheduler


# Useful function for later
def load_image(url, size=None):
    response = requests.get(url, timeout=0.2)
    img = Image.open(BytesIO(response.content)).convert("RGB")
    if size is not None:
        img = img.resize(size)
    return img
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

加载现有 pipeline

# Load a pipeline
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)
# Set up a DDIM scheduler
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
>>> # Sample an image to make sure it is all working
>>> prompt = "Beautiful DSLR Photograph of a penguin on the beach, golden hour"
>>> negative_prompt = "blurry, ugly, stock photo"
>>> im = pipe(prompt, negative_prompt=negative_prompt).images[0]
>>> im.resize((256, 256))  # Resize for convenient viewing

DDIM 采样

在给定的时间 $t$,带噪图像 $x_t$ 是原始图像 ($x_0$) 和一些噪声 ($\epsilon$) 的混合。这是 DDIM 论文中 $x_t$ 的公式,我们将在本节中参考它。xt=αtx0+1αtϵ x_t = \sqrt{\alpha_t}x_0 + \sqrt{1-\alpha_t}\epsilon

$\epsilon$ 是单位方差的高斯噪声。$\alpha_t$('alpha')在 DDPM 论文中被令人困惑地称为 $\bar{\alpha}$('alpha_bar')(!!),并定义了噪声调度器。在 Diffusers 中,alpha 调度器被计算出来,其值存储在 `scheduler.alphas_cumprod` 中。我知道这很令人困惑!让我们绘制这些值,并记住在本 notebook 的其余部分,我们将使用 DDIM 的表示法。

>>> # Plot 'alpha' (alpha_bar in DDPM language, alphas_cumprod in Diffusers for clarity)
>>> timesteps = pipe.scheduler.timesteps.cpu()
>>> alphas = pipe.scheduler.alphas_cumprod[timesteps]
>>> plt.plot(timesteps, alphas, label="alpha_t")
>>> plt.legend()

最初(时间步 0,图的左侧),我们从一张清晰的图像开始,没有噪声,$\alpha_t = 1$。随着我们向更高的时间步移动,我们最终得到几乎全是噪声的图像,并且 $\alpha_t$ 趋向于 0。

在采样过程中,我们从时间步 1000 的纯噪声开始,慢慢向时间步 0 移动。为了计算采样轨迹中的下一个 t($x_{t-1}$,因为我们从高 t 移动到低 t),我们预测噪声($\epsilon_\theta(x_t)$,这是我们模型的输出),并用它来计算预测的去噪图像 $x_0$。然后我们使用这个预测在“指向 $x_t$ 的方向”上移动一小段距离。最后,我们可以添加一些由 $\sigma_t$ 缩放的额外噪声。以下是论文中展示这一过程的相关部分。

Screenshot from 2023-01-28 10-04-22.png

所以,我们有了一个从 $x_t$ 移动到 $x_{t-1}$ 的方程,并且可以控制噪声量。今天我们特别感兴趣的是不添加任何额外噪声的情况——这给了我们完全确定性的 DDIM 采样。让我们看看这在代码中是什么样子。

# Sample function (regular DDIM)
@torch.no_grad()
def sample(
    prompt,
    start_step=0,
    start_latents=None,
    guidance_scale=3.5,
    num_inference_steps=30,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
    device=device,
):

    # Encode prompt
    text_embeddings = pipe._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # Set num inference steps
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    # Create a random starting point if we don't have one already
    if start_latents is None:
        start_latents = torch.randn(1, 4, 64, 64, device=device)
        start_latents *= pipe.scheduler.init_noise_sigma

    latents = start_latents.clone()

    for i in tqdm(range(start_step, num_inference_steps)):

        t = pipe.scheduler.timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # Perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Normally we'd rely on the scheduler to handle the update step:
        # latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

        # Instead, let's do it ourselves:
        prev_t = max(1, t.item() - (1000 // num_inference_steps))  # t-1
        alpha_t = pipe.scheduler.alphas_cumprod[t.item()]
        alpha_t_prev = pipe.scheduler.alphas_cumprod[prev_t]
        predicted_x0 = (latents - (1 - alpha_t).sqrt() * noise_pred) / alpha_t.sqrt()
        direction_pointing_to_xt = (1 - alpha_t_prev).sqrt() * noise_pred
        latents = alpha_t_prev.sqrt() * predicted_x0 + direction_pointing_to_xt

    # Post-processing
    images = pipe.decode_latents(latents)
    images = pipe.numpy_to_pil(images)

    return images
>>> # Test our sampling function by generating an image
>>> sample("Watercolor painting of a beach sunset", negative_prompt=negative_prompt, num_inference_steps=50)[0].resize(
...     (256, 256)
... )

看看你是否能将代码与论文中的方程对应起来。请注意,$\sigma$=0,因为我们只对没有额外噪声的情况感兴趣,所以我们可以省略方程中的那些部分。

反演

反演的目标是“逆转”采样过程。我们希望最终得到一个带噪的潜变量,如果将其用作我们常规采样过程的起点,将生成原始图像。

在这里,我们将加载一张图像作为我们的初始图像,但你也可以自己生成一张来代替。

>>> # https://www.pexels.com/photo/a-beagle-on-green-grass-field-8306128/
>>> input_image = load_image("https://images.pexels.com/photos/8306128/pexels-photo-8306128.jpeg", size=(512, 512))
>>> input_image

我们还将使用一个提示(prompt)来进行包含无分类器引导的反演,所以请输入图像的描述。

input_image_prompt = "Photograph of a puppy on the grass"

接下来,我们需要将这个 PIL 图像转换成一组潜变量,我们将用它作为我们反演的起点。

# Encode with VAE
with torch.no_grad():
    latent = pipe.vae.encode(tfms.functional.to_tensor(input_image).unsqueeze(0).to(device) * 2 - 1)
l = 0.18215 * latent.latent_dist.sample()

好了,到了有趣的部分。这个函数看起来与上面的采样函数相似,但我们以相反的方向遍历时间步,从 t=0 开始,朝着越来越高的噪声移动。并且,我们不是更新我们的潜变量使其噪声更少,而是估计预测的噪声并用它来“撤销”一个更新步骤,将它们从 t 移动到 t+1。

## Inversion
@torch.no_grad()
def invert(
    start_latents,
    prompt,
    guidance_scale=3.5,
    num_inference_steps=80,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
    device=device,
):

    # Encode prompt
    text_embeddings = pipe._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # Latents are now the specified start latents
    latents = start_latents.clone()

    # We'll keep a list of the inverted latents as the process goes on
    intermediate_latents = []

    # Set num inference steps
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    # Reversed timesteps <<<<<<<<<<<<<<<<<<<<
    timesteps = reversed(pipe.scheduler.timesteps)

    for i in tqdm(range(1, num_inference_steps), total=num_inference_steps - 1):

        # We'll skip the final iteration
        if i >= num_inference_steps - 1:
            continue

        t = timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # Perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        current_t = max(0, t.item() - (1000 // num_inference_steps))  # t
        next_t = t  # min(999, t.item() + (1000//num_inference_steps)) # t+1
        alpha_t = pipe.scheduler.alphas_cumprod[current_t]
        alpha_t_next = pipe.scheduler.alphas_cumprod[next_t]

        # Inverted update step (re-arranging the update step to get x(t) (new latents) as a function of x(t-1) (current latents)
        latents = (latents - (1 - alpha_t).sqrt() * noise_pred) * (alpha_t_next.sqrt() / alpha_t.sqrt()) + (
            1 - alpha_t_next
        ).sqrt() * noise_pred

        # Store
        intermediate_latents.append(latents)

    return torch.cat(intermediate_latents)

在我们小狗图片的潜变量表示上运行它,我们会得到在反演过程中创建的所有中间潜变量的集合。

inverted_latents = invert(l, input_image_prompt, num_inference_steps=50)
inverted_latents.shape

我们可以查看最终的潜变量集——这些有望成为我们新采样尝试的带噪起点。

>>> # Decode the final inverted latents
>>> with torch.no_grad():
...     im = pipe.decode_latents(inverted_latents[-1].unsqueeze(0))
>>> pipe.numpy_to_pil(im)[0]

你可以使用常规的 call 方法将这些反演后的潜变量传递给 pipeline。

>>> pipe(input_image_prompt, latents=inverted_latents[-1][None], num_inference_steps=50, guidance_scale=3.5).images[0]

但在这里我们看到了第一个问题:这并不完全是我们开始时的图像!这是因为 DDIM 反演依赖于一个关键假设,即时间 t 和时间 t+1 的噪声预测是相同的——当我们只在 50 或 100 个时间步上进行反演时,这个假设并不成立。我们可以使用更多的时间步来希望能得到更准确的反演,但我们也可以“作弊”,例如,从采样过程的 20/50 步开始,并使用我们在反演过程中保存的相应中间潜变量。

>>> # The reason we want to be able to specify start step
>>> start_step = 20
>>> sample(
...     input_image_prompt,
...     start_latents=inverted_latents[-(start_step + 1)][None],
...     start_step=start_step,
...     num_inference_steps=50,
... )[0]

非常接近我们的输入图像!我们为什么要这样做呢?嗯,希望是如果我们现在用一个新的提示进行采样,我们会得到一张与原始图像相匹配的图片,除了与新提示相关的地方。例如,用“猫”替换“小狗”,我们应该会看到一只猫,而草坪和背景几乎完全相同。

>>> # Sampling with a new prompt
>>> start_step = 10
>>> new_prompt = input_image_prompt.replace("puppy", "cat")
>>> sample(
...     new_prompt,
...     start_latents=inverted_latents[-(start_step + 1)][None],
...     start_step=start_step,
...     num_inference_steps=50,
... )[0]

为什么不直接使用 img2img?

为什么要费力进行反演?我们不可以直接向输入图像添加噪声,然后用新的提示进行去噪吗?我们可以,但这会导致各处发生更剧烈的变化(如果我们添加大量噪声)或者任何地方的变化都不够(如果我们添加较少噪声)。自己试试吧。

>>> start_step = 10
>>> num_inference_steps = 50
>>> pipe.scheduler.set_timesteps(num_inference_steps)
>>> noisy_l = pipe.scheduler.add_noise(l, torch.randn_like(l), pipe.scheduler.timesteps[start_step])
>>> sample(new_prompt, start_latents=noisy_l, start_step=start_step, num_inference_steps=num_inference_steps)[0]

注意草坪和背景发生了更大的变化。

整合一切

让我们把到目前为止编写的代码封装成一个简单的函数,它接受一张图像和两个提示,并使用反演执行编辑。

def edit(input_image, input_image_prompt, edit_prompt, num_steps=100, start_step=30, guidance_scale=3.5):
    with torch.no_grad():
        latent = pipe.vae.encode(tfms.functional.to_tensor(input_image).unsqueeze(0).to(device) * 2 - 1)
    l = 0.18215 * latent.latent_dist.sample()
    inverted_latents = invert(l, input_image_prompt, num_inference_steps=num_steps)
    final_im = sample(
        edit_prompt,
        start_latents=inverted_latents[-(start_step + 1)][None],
        start_step=start_step,
        num_inference_steps=num_steps,
        guidance_scale=guidance_scale,
    )[0]
    return final_im

这是实际操作的效果。

>>> edit(input_image, "A puppy on the grass", "an old grey dog on the grass", num_steps=50, start_step=10)
>>> edit(input_image, "A puppy on the grass", "A blue dog on the lawn", num_steps=50, start_step=12, guidance_scale=6)
# Exercise: Try this on some more images! Explore the different parameters.

更多步骤 = 更好性能

如果你在反演精度较低时遇到问题,可以尝试使用更多的步骤(代价是更长的运行时间)。要测试反演效果,你可以使用我们的编辑函数并传入相同的提示。

>>> # Inversion test with far more steps
>>> edit(input_image, "A puppy on the grass", "A puppy on the grass", num_steps=350, start_step=1)

好多了!再试试用它进行编辑。

>>> edit(
...     input_image,
...     "A photograph of a puppy",
...     "A photograph of a grey cat",
...     num_steps=150,
...     start_step=30,
...     guidance_scale=5.5,
... )
>>> # source: https://www.pexels.com/photo/girl-taking-photo-1493111/
>>> face = load_image("https://images.pexels.com/photos/1493111/pexels-photo-1493111.jpeg", size=(512, 512))
>>> face
>>> edit(
...     face,
...     "A photograph of a face",
...     "A photograph of a face with sunglasses",
...     num_steps=250,
...     start_step=30,
...     guidance_scale=3.5,
... )
>>> edit(
...     face,
...     "A photograph of a face",
...     "Acrylic palette knife painting of a face, colorful",
...     num_steps=250,
...     start_step=65,
...     guidance_scale=5.5,
... )

接下来呢?

掌握了这个 notebook 的知识后,我建议你研究一下 “空文本反演” (Null-text Inversion),它在 DDIM 的基础上,通过在反演过程中优化空文本(无条件文本提示)来实现更准确的反演和更好的编辑效果。

< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.