DDIM 逆转
在本笔记本中,我们将探讨 **逆转**,了解它与采样的关系,并将它应用于使用稳定扩散编辑图像的任务。
您将学到什么
- DDIM 采样工作原理
- 确定性与随机采样器
- DDIM 逆转背后的理论
- 使用逆转编辑图像
让我们开始吧!
设置
%pip install -q transformers diffusers accelerate
import torch
import requests
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from io import BytesIO
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
from torchvision import transforms as tfms
from diffusers import StableDiffusionPipeline, DDIMScheduler
# Useful function for later
def load_image(url, size=None):
response = requests.get(url, timeout=0.2)
img = Image.open(BytesIO(response.content)).convert("RGB")
if size is not None:
img = img.resize(size)
return img
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
加载现有管道
# Load a pipeline
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)
# Set up a DDIM scheduler
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
>>> # Sample an image to make sure it is all working
>>> prompt = "Beautiful DSLR Photograph of a penguin on the beach, golden hour"
>>> negative_prompt = "blurry, ugly, stock photo"
>>> im = pipe(prompt, negative_prompt=negative_prompt).images[0]
>>> im.resize((256, 256)) # Resize for convenient viewing
DDIM 采样
在给定的时间 $t$,噪声图像 $x_t$ 是原始图像 ($x_0$) 和一些噪声 ($\epsilon$) 的混合物。以下是 DDIM 论文中关于 $x_t$ 的公式,我们将在此部分中参考它
$\epsilon$ 是一个均值为 0,方差为 1 的高斯噪声。$\alpha_t$('alpha')是 DDPM 论文中混淆地称为 $\bar{\alpha}$('alpha_bar')的值,它定义了噪声调度器。在 Diffusers 中,alpha 调度器是计算出来的,其值存储在 scheduler.alphas_cumprod
中。我知道这很混乱!让我们绘制这些值,并记住在接下来的笔记本中,我们将使用 DDIM 的符号。
>>> # Plot 'alpha' (alpha_bar in DDPM language, alphas_cumprod in Diffusers for clarity)
>>> timesteps = pipe.scheduler.timesteps.cpu()
>>> alphas = pipe.scheduler.alphas_cumprod[timesteps]
>>> plt.plot(timesteps, alphas, label="alpha_t")
>>> plt.legend()
最初(时间步长为 0,图的左侧),我们从干净的图像开始,没有噪声。$\alpha_t = 1$。当我们移动到更高的时间步长时,我们最终会得到几乎全是噪声,$\alpha_t$ 降至 0。
在采样过程中,我们从时间步长为 1000 的纯噪声开始,并慢慢地移动到时间步长为 0。为了计算采样轨迹中的下一个 t($x{t-1}$ 因为我们是从高 t 到低 t 移动),我们预测噪声($\epsilon\theta(x_t)$,这是我们模型的输出),并用它来计算预测的去噪图像 $x_0$。然后,我们使用这个预测在“指向 $x_t$ 的方向”上移动一小段距离。最后,我们可以添加一些由 $\sigma_t$ 缩放的额外噪声。以下是从论文中摘录的相关部分,显示了它的实际操作
因此,我们有一个关于如何从 $xt$ 移动到 $x{t-1}$ 的公式,其中包含可控量的噪声。而今天我们特别关注的是不添加任何额外噪声的情况 - 这将为我们提供完全确定的 DDIM 采样。让我们看看它在代码中是什么样的
# Sample function (regular DDIM)
@torch.no_grad()
def sample(
prompt,
start_step=0,
start_latents=None,
guidance_scale=3.5,
num_inference_steps=30,
num_images_per_prompt=1,
do_classifier_free_guidance=True,
negative_prompt="",
device=device,
):
# Encode prompt
text_embeddings = pipe._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
# Set num inference steps
pipe.scheduler.set_timesteps(num_inference_steps, device=device)
# Create a random starting point if we don't have one already
if start_latents is None:
start_latents = torch.randn(1, 4, 64, 64, device=device)
start_latents *= pipe.scheduler.init_noise_sigma
latents = start_latents.clone()
for i in tqdm(range(start_step, num_inference_steps)):
t = pipe.scheduler.timesteps[i]
# Expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
# Predict the noise residual
noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# Perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# Normally we'd rely on the scheduler to handle the update step:
# latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
# Instead, let's do it ourselves:
prev_t = max(1, t.item() - (1000 // num_inference_steps)) # t-1
alpha_t = pipe.scheduler.alphas_cumprod[t.item()]
alpha_t_prev = pipe.scheduler.alphas_cumprod[prev_t]
predicted_x0 = (latents - (1 - alpha_t).sqrt() * noise_pred) / alpha_t.sqrt()
direction_pointing_to_xt = (1 - alpha_t_prev).sqrt() * noise_pred
latents = alpha_t_prev.sqrt() * predicted_x0 + direction_pointing_to_xt
# Post-processing
images = pipe.decode_latents(latents)
images = pipe.numpy_to_pil(images)
return images
>>> # Test our sampling function by generating an image
>>> sample("Watercolor painting of a beach sunset", negative_prompt=negative_prompt, num_inference_steps=50)[0].resize(
... (256, 256)
... )
看看你是否可以将代码与论文中的公式匹配起来。请注意,$\sigma$=0 因为我们只关注没有额外噪声的情况,因此我们可以省略公式的那些部分。
反转
反转的目的是“逆转”采样过程。我们希望最终得到一个噪声潜码,如果将其用作我们通常的采样过程的起点,就会生成原始图像。
在这里,我们将图像加载为我们的初始图像,但你也可以自己生成一个图像来代替它。
>>> # https://www.pexels.com/photo/a-beagle-on-green-grass-field-8306128/
>>> input_image = load_image("https://images.pexels.com/photos/8306128/pexels-photo-8306128.jpeg", size=(512, 512))
>>> input_image
我们还将使用一个提示词来执行反转,其中包含无分类器引导,所以请输入图像的描述。
input_image_prompt = "Photograph of a puppy on the grass"
接下来,我们需要将这个 PIL 图像转换成一组潜码,我们将用它作为我们反转的起点。
# Encode with VAE
with torch.no_grad():
latent = pipe.vae.encode(tfms.functional.to_tensor(input_image).unsqueeze(0).to(device) * 2 - 1)
l = 0.18215 * latent.latent_dist.sample()
好了,现在是有趣的部分了。这个函数与上面的采样函数很像,但我们以相反的方向遍历时间步长,从 t=0 开始,并逐渐增加噪声。而且,我们不是更新我们的潜码使其变得不那么嘈杂,而是估计预测的噪声,并用它来撤消更新步骤,将它们从 t 移动到 t+1。
## Inversion
@torch.no_grad()
def invert(
start_latents,
prompt,
guidance_scale=3.5,
num_inference_steps=80,
num_images_per_prompt=1,
do_classifier_free_guidance=True,
negative_prompt="",
device=device,
):
# Encode prompt
text_embeddings = pipe._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
# Latents are now the specified start latents
latents = start_latents.clone()
# We'll keep a list of the inverted latents as the process goes on
intermediate_latents = []
# Set num inference steps
pipe.scheduler.set_timesteps(num_inference_steps, device=device)
# Reversed timesteps <<<<<<<<<<<<<<<<<<<<
timesteps = reversed(pipe.scheduler.timesteps)
for i in tqdm(range(1, num_inference_steps), total=num_inference_steps - 1):
# We'll skip the final iteration
if i >= num_inference_steps - 1:
continue
t = timesteps[i]
# Expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
# Predict the noise residual
noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# Perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
current_t = max(0, t.item() - (1000 // num_inference_steps)) # t
next_t = t # min(999, t.item() + (1000//num_inference_steps)) # t+1
alpha_t = pipe.scheduler.alphas_cumprod[current_t]
alpha_t_next = pipe.scheduler.alphas_cumprod[next_t]
# Inverted update step (re-arranging the update step to get x(t) (new latents) as a function of x(t-1) (current latents)
latents = (latents - (1 - alpha_t).sqrt() * noise_pred) * (alpha_t_next.sqrt() / alpha_t.sqrt()) + (
1 - alpha_t_next
).sqrt() * noise_pred
# Store
intermediate_latents.append(latents)
return torch.cat(intermediate_latents)
在小狗图片的潜码表示上运行它,我们将得到一组在反转过程中创建的所有中间潜码。
inverted_latents = invert(l, input_image_prompt, num_inference_steps=50)
inverted_latents.shape
我们可以查看最终的潜码集 - 这些潜码有望成为我们新的采样尝试的噪声起点。
>>> # Decode the final inverted latents
>>> with torch.no_grad():
... im = pipe.decode_latents(inverted_latents[-1].unsqueeze(0))
>>> pipe.numpy_to_pil(im)[0]
你可以通过使用正常的 call 方法将这些反转后的潜码传递到管道中。
>>> pipe(input_image_prompt, latents=inverted_latents[-1][None], num_inference_steps=50, guidance_scale=3.5).images[0]
但在这里我们看到了第一个问题:这**并不完全是我们开始的图像**!这是因为 DDIM 反转依赖于一个关键假设,即时间 t 和时间 t+1 的噪声预测将是相同的 - 当我们只在 50 或 100 个时间步长上反转时,这是不正确的。我们可以使用更多的时间步长来获得更准确的反转,但我们也可以“作弊”,从例如采样的 20/50 步开始,并使用我们在反转过程中保存的相应中间潜码。
>>> # The reason we want to be able to specify start step
>>> start_step = 20
>>> sample(
... input_image_prompt,
... start_latents=inverted_latents[-(start_step + 1)][None],
... start_step=start_step,
... num_inference_steps=50,
... )[0]
非常接近我们的输入图像!我们为什么要这样做?好吧,希望如果我们现在用新的提示词进行采样,我们将得到一张图像,它与原始图像匹配,除了与新提示词相关的部分之外。例如,用“猫”替换“小狗”,我们应该看到一只猫,它有着几乎相同的草坪和背景。
>>> # Sampling with a new prompt
>>> start_step = 10
>>> new_prompt = input_image_prompt.replace("puppy", "cat")
>>> sample(
... new_prompt,
... start_latents=inverted_latents[-(start_step + 1)][None],
... start_step=start_step,
... num_inference_steps=50,
... )[0]
为什么不直接使用 img2img?
为什么要费心反转?我们不能直接在输入图像上添加噪声,然后用新的提示词进行去噪吗?我们可以,但这会导致在各个地方发生更加剧烈的变化(如果我们添加大量噪声)或者在任何地方都不够明显的变化(如果我们添加较少的噪声)。自己试试看。
>>> start_step = 10
>>> num_inference_steps = 50
>>> pipe.scheduler.set_timesteps(num_inference_steps)
>>> noisy_l = pipe.scheduler.add_noise(l, torch.randn_like(l), pipe.scheduler.timesteps[start_step])
>>> sample(new_prompt, start_latents=noisy_l, start_step=start_step, num_inference_steps=num_inference_steps)[0]
请注意草坪和背景的变化幅度更大。
将所有内容整合在一起
让我们将到目前为止编写的代码封装成一个简单的函数,该函数接受一张图像和两个提示词,并使用反转执行编辑。
def edit(input_image, input_image_prompt, edit_prompt, num_steps=100, start_step=30, guidance_scale=3.5):
with torch.no_grad():
latent = pipe.vae.encode(tfms.functional.to_tensor(input_image).unsqueeze(0).to(device) * 2 - 1)
l = 0.18215 * latent.latent_dist.sample()
inverted_latents = invert(l, input_image_prompt, num_inference_steps=num_steps)
final_im = sample(
edit_prompt,
start_latents=inverted_latents[-(start_step + 1)][None],
start_step=start_step,
num_inference_steps=num_steps,
guidance_scale=guidance_scale,
)[0]
return final_im
以及实际操作。
>>> edit(input_image, "A puppy on the grass", "an old grey dog on the grass", num_steps=50, start_step=10)
>>> edit(input_image, "A puppy on the grass", "A blue dog on the lawn", num_steps=50, start_step=12, guidance_scale=6)
# Exercise: Try this on some more images! Explore the different parameters.
更多步骤 = 更好的性能
如果您遇到反演精度较低的问题,可以尝试使用更多步骤(以增加运行时间为代价)。要测试反演,可以使用相同的提示使用我们的编辑功能。
>>> # Inversion test with far more steps
>>> edit(input_image, "A puppy on the grass", "A puppy on the grass", num_steps=350, start_step=1)
好多了!并尝试将其用于编辑。
>>> edit(
... input_image,
... "A photograph of a puppy",
... "A photograph of a grey cat",
... num_steps=150,
... start_step=30,
... guidance_scale=5.5,
... )
>>> # source: https://www.pexels.com/photo/girl-taking-photo-1493111/
>>> face = load_image("https://images.pexels.com/photos/1493111/pexels-photo-1493111.jpeg", size=(512, 512))
>>> face
>>> edit(
... face,
... "A photograph of a face",
... "A photograph of a face with sunglasses",
... num_steps=250,
... start_step=30,
... guidance_scale=3.5,
... )
>>> edit(
... face,
... "A photograph of a face",
... "Acrylic palette knife painting of a face, colorful",
... num_steps=250,
... start_step=65,
... guidance_scale=5.5,
... )
下一步做什么?
掌握了这个笔记本中的知识后,我建议您研究‘Null-text Inversion’,它基于 DDIM,通过在反演过程中优化空文本(无条件文本提示)来实现更准确的反演和更好的编辑。