扩散课程文档

DDIM 逆转

Hugging Face's logo
加入 Hugging Face 社区

并获得增强型文档体验

开始使用

Open In Colab

DDIM 逆转

在本笔记本中,我们将探讨 **逆转**,了解它与采样的关系,并将它应用于使用稳定扩散编辑图像的任务。

您将学到什么

  • DDIM 采样工作原理
  • 确定性与随机采样器
  • DDIM 逆转背后的理论
  • 使用逆转编辑图像

让我们开始吧!

设置

%pip install -q transformers diffusers accelerate
import torch
import requests
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from io import BytesIO
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
from torchvision import transforms as tfms
from diffusers import StableDiffusionPipeline, DDIMScheduler


# Useful function for later
def load_image(url, size=None):
    response = requests.get(url, timeout=0.2)
    img = Image.open(BytesIO(response.content)).convert("RGB")
    if size is not None:
        img = img.resize(size)
    return img
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

加载现有管道

# Load a pipeline
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)
# Set up a DDIM scheduler
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
>>> # Sample an image to make sure it is all working
>>> prompt = "Beautiful DSLR Photograph of a penguin on the beach, golden hour"
>>> negative_prompt = "blurry, ugly, stock photo"
>>> im = pipe(prompt, negative_prompt=negative_prompt).images[0]
>>> im.resize((256, 256))  # Resize for convenient viewing

DDIM 采样

在给定的时间 $t$,噪声图像 $x_t$ 是原始图像 ($x_0$) 和一些噪声 ($\epsilon$) 的混合物。以下是 DDIM 论文中关于 $x_t$ 的公式,我们将在此部分中参考它xt=αtx0+1αtϵ x_t = \sqrt{\alpha_t}x_0 + \sqrt{1-\alpha_t}\epsilon

$\epsilon$ 是一个均值为 0,方差为 1 的高斯噪声。$\alpha_t$('alpha')是 DDPM 论文中混淆地称为 $\bar{\alpha}$('alpha_bar')的值,它定义了噪声调度器。在 Diffusers 中,alpha 调度器是计算出来的,其值存储在 scheduler.alphas_cumprod 中。我知道这很混乱!让我们绘制这些值,并记住在接下来的笔记本中,我们将使用 DDIM 的符号。

>>> # Plot 'alpha' (alpha_bar in DDPM language, alphas_cumprod in Diffusers for clarity)
>>> timesteps = pipe.scheduler.timesteps.cpu()
>>> alphas = pipe.scheduler.alphas_cumprod[timesteps]
>>> plt.plot(timesteps, alphas, label="alpha_t")
>>> plt.legend()

最初(时间步长为 0,图的左侧),我们从干净的图像开始,没有噪声。$\alpha_t = 1$。当我们移动到更高的时间步长时,我们最终会得到几乎全是噪声,$\alpha_t$ 降至 0。

在采样过程中,我们从时间步长为 1000 的纯噪声开始,并慢慢地移动到时间步长为 0。为了计算采样轨迹中的下一个 t($x{t-1}$ 因为我们是从高 t 到低 t 移动),我们预测噪声($\epsilon\theta(x_t)$,这是我们模型的输出),并用它来计算预测的去噪图像 $x_0$。然后,我们使用这个预测在“指向 $x_t$ 的方向”上移动一小段距离。最后,我们可以添加一些由 $\sigma_t$ 缩放的额外噪声。以下是从论文中摘录的相关部分,显示了它的实际操作

Screenshot from 2023-01-28 10-04-22.png

因此,我们有一个关于如何从 $xt$ 移动到 $x{t-1}$ 的公式,其中包含可控量的噪声。而今天我们特别关注的是不添加任何额外噪声的情况 - 这将为我们提供完全确定的 DDIM 采样。让我们看看它在代码中是什么样的

# Sample function (regular DDIM)
@torch.no_grad()
def sample(
    prompt,
    start_step=0,
    start_latents=None,
    guidance_scale=3.5,
    num_inference_steps=30,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
    device=device,
):

    # Encode prompt
    text_embeddings = pipe._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # Set num inference steps
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    # Create a random starting point if we don't have one already
    if start_latents is None:
        start_latents = torch.randn(1, 4, 64, 64, device=device)
        start_latents *= pipe.scheduler.init_noise_sigma

    latents = start_latents.clone()

    for i in tqdm(range(start_step, num_inference_steps)):

        t = pipe.scheduler.timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # Perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Normally we'd rely on the scheduler to handle the update step:
        # latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

        # Instead, let's do it ourselves:
        prev_t = max(1, t.item() - (1000 // num_inference_steps))  # t-1
        alpha_t = pipe.scheduler.alphas_cumprod[t.item()]
        alpha_t_prev = pipe.scheduler.alphas_cumprod[prev_t]
        predicted_x0 = (latents - (1 - alpha_t).sqrt() * noise_pred) / alpha_t.sqrt()
        direction_pointing_to_xt = (1 - alpha_t_prev).sqrt() * noise_pred
        latents = alpha_t_prev.sqrt() * predicted_x0 + direction_pointing_to_xt

    # Post-processing
    images = pipe.decode_latents(latents)
    images = pipe.numpy_to_pil(images)

    return images
>>> # Test our sampling function by generating an image
>>> sample("Watercolor painting of a beach sunset", negative_prompt=negative_prompt, num_inference_steps=50)[0].resize(
...     (256, 256)
... )

看看你是否可以将代码与论文中的公式匹配起来。请注意,$\sigma$=0 因为我们只关注没有额外噪声的情况,因此我们可以省略公式的那些部分。

反转

反转的目的是“逆转”采样过程。我们希望最终得到一个噪声潜码,如果将其用作我们通常的采样过程的起点,就会生成原始图像。

在这里,我们将图像加载为我们的初始图像,但你也可以自己生成一个图像来代替它。

>>> # https://www.pexels.com/photo/a-beagle-on-green-grass-field-8306128/
>>> input_image = load_image("https://images.pexels.com/photos/8306128/pexels-photo-8306128.jpeg", size=(512, 512))
>>> input_image

我们还将使用一个提示词来执行反转,其中包含无分类器引导,所以请输入图像的描述。

input_image_prompt = "Photograph of a puppy on the grass"

接下来,我们需要将这个 PIL 图像转换成一组潜码,我们将用它作为我们反转的起点。

# Encode with VAE
with torch.no_grad():
    latent = pipe.vae.encode(tfms.functional.to_tensor(input_image).unsqueeze(0).to(device) * 2 - 1)
l = 0.18215 * latent.latent_dist.sample()

好了,现在是有趣的部分了。这个函数与上面的采样函数很像,但我们以相反的方向遍历时间步长,从 t=0 开始,并逐渐增加噪声。而且,我们不是更新我们的潜码使其变得不那么嘈杂,而是估计预测的噪声,并用它来撤消更新步骤,将它们从 t 移动到 t+1。

## Inversion
@torch.no_grad()
def invert(
    start_latents,
    prompt,
    guidance_scale=3.5,
    num_inference_steps=80,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
    device=device,
):

    # Encode prompt
    text_embeddings = pipe._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # Latents are now the specified start latents
    latents = start_latents.clone()

    # We'll keep a list of the inverted latents as the process goes on
    intermediate_latents = []

    # Set num inference steps
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    # Reversed timesteps <<<<<<<<<<<<<<<<<<<<
    timesteps = reversed(pipe.scheduler.timesteps)

    for i in tqdm(range(1, num_inference_steps), total=num_inference_steps - 1):

        # We'll skip the final iteration
        if i >= num_inference_steps - 1:
            continue

        t = timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # Perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        current_t = max(0, t.item() - (1000 // num_inference_steps))  # t
        next_t = t  # min(999, t.item() + (1000//num_inference_steps)) # t+1
        alpha_t = pipe.scheduler.alphas_cumprod[current_t]
        alpha_t_next = pipe.scheduler.alphas_cumprod[next_t]

        # Inverted update step (re-arranging the update step to get x(t) (new latents) as a function of x(t-1) (current latents)
        latents = (latents - (1 - alpha_t).sqrt() * noise_pred) * (alpha_t_next.sqrt() / alpha_t.sqrt()) + (
            1 - alpha_t_next
        ).sqrt() * noise_pred

        # Store
        intermediate_latents.append(latents)

    return torch.cat(intermediate_latents)

在小狗图片的潜码表示上运行它,我们将得到一组在反转过程中创建的所有中间潜码。

inverted_latents = invert(l, input_image_prompt, num_inference_steps=50)
inverted_latents.shape

我们可以查看最终的潜码集 - 这些潜码有望成为我们新的采样尝试的噪声起点。

>>> # Decode the final inverted latents
>>> with torch.no_grad():
...     im = pipe.decode_latents(inverted_latents[-1].unsqueeze(0))
>>> pipe.numpy_to_pil(im)[0]

你可以通过使用正常的 call 方法将这些反转后的潜码传递到管道中。

>>> pipe(input_image_prompt, latents=inverted_latents[-1][None], num_inference_steps=50, guidance_scale=3.5).images[0]

但在这里我们看到了第一个问题:这**并不完全是我们开始的图像**!这是因为 DDIM 反转依赖于一个关键假设,即时间 t 和时间 t+1 的噪声预测将是相同的 - 当我们只在 50 或 100 个时间步长上反转时,这是不正确的。我们可以使用更多的时间步长来获得更准确的反转,但我们也可以“作弊”,从例如采样的 20/50 步开始,并使用我们在反转过程中保存的相应中间潜码。

>>> # The reason we want to be able to specify start step
>>> start_step = 20
>>> sample(
...     input_image_prompt,
...     start_latents=inverted_latents[-(start_step + 1)][None],
...     start_step=start_step,
...     num_inference_steps=50,
... )[0]

非常接近我们的输入图像!我们为什么要这样做?好吧,希望如果我们现在用新的提示词进行采样,我们将得到一张图像,它与原始图像匹配,除了与新提示词相关的部分之外。例如,用“猫”替换“小狗”,我们应该看到一只猫,它有着几乎相同的草坪和背景。

>>> # Sampling with a new prompt
>>> start_step = 10
>>> new_prompt = input_image_prompt.replace("puppy", "cat")
>>> sample(
...     new_prompt,
...     start_latents=inverted_latents[-(start_step + 1)][None],
...     start_step=start_step,
...     num_inference_steps=50,
... )[0]

为什么不直接使用 img2img?

为什么要费心反转?我们不能直接在输入图像上添加噪声,然后用新的提示词进行去噪吗?我们可以,但这会导致在各个地方发生更加剧烈的变化(如果我们添加大量噪声)或者在任何地方都不够明显的变化(如果我们添加较少的噪声)。自己试试看。

>>> start_step = 10
>>> num_inference_steps = 50
>>> pipe.scheduler.set_timesteps(num_inference_steps)
>>> noisy_l = pipe.scheduler.add_noise(l, torch.randn_like(l), pipe.scheduler.timesteps[start_step])
>>> sample(new_prompt, start_latents=noisy_l, start_step=start_step, num_inference_steps=num_inference_steps)[0]

请注意草坪和背景的变化幅度更大。

将所有内容整合在一起

让我们将到目前为止编写的代码封装成一个简单的函数,该函数接受一张图像和两个提示词,并使用反转执行编辑。

def edit(input_image, input_image_prompt, edit_prompt, num_steps=100, start_step=30, guidance_scale=3.5):
    with torch.no_grad():
        latent = pipe.vae.encode(tfms.functional.to_tensor(input_image).unsqueeze(0).to(device) * 2 - 1)
    l = 0.18215 * latent.latent_dist.sample()
    inverted_latents = invert(l, input_image_prompt, num_inference_steps=num_steps)
    final_im = sample(
        edit_prompt,
        start_latents=inverted_latents[-(start_step + 1)][None],
        start_step=start_step,
        num_inference_steps=num_steps,
        guidance_scale=guidance_scale,
    )[0]
    return final_im

以及实际操作。

>>> edit(input_image, "A puppy on the grass", "an old grey dog on the grass", num_steps=50, start_step=10)
>>> edit(input_image, "A puppy on the grass", "A blue dog on the lawn", num_steps=50, start_step=12, guidance_scale=6)
# Exercise: Try this on some more images! Explore the different parameters.

更多步骤 = 更好的性能

如果您遇到反演精度较低的问题,可以尝试使用更多步骤(以增加运行时间为代价)。要测试反演,可以使用相同的提示使用我们的编辑功能。

>>> # Inversion test with far more steps
>>> edit(input_image, "A puppy on the grass", "A puppy on the grass", num_steps=350, start_step=1)

好多了!并尝试将其用于编辑。

>>> edit(
...     input_image,
...     "A photograph of a puppy",
...     "A photograph of a grey cat",
...     num_steps=150,
...     start_step=30,
...     guidance_scale=5.5,
... )
>>> # source: https://www.pexels.com/photo/girl-taking-photo-1493111/
>>> face = load_image("https://images.pexels.com/photos/1493111/pexels-photo-1493111.jpeg", size=(512, 512))
>>> face
>>> edit(
...     face,
...     "A photograph of a face",
...     "A photograph of a face with sunglasses",
...     num_steps=250,
...     start_step=30,
...     guidance_scale=3.5,
... )
>>> edit(
...     face,
...     "A photograph of a face",
...     "Acrylic palette knife painting of a face, colorful",
...     num_steps=250,
...     start_step=65,
...     guidance_scale=5.5,
... )

下一步做什么?

掌握了这个笔记本中的知识后,我建议您研究‘Null-text Inversion’,它基于 DDIM,通过在反演过程中优化空文本(无条件文本提示)来实现更准确的反演和更好的编辑。