Diffusion Models from Scratch
Sometimes it helps to consider the simplest possible version of something in order to understand how it works. We're going to try that in this notebook, beginning with a 'toy' diffusion model to see how the different pieces work, and then examining how they differ from a more complex implementation.
We will look at:
- The corruption process (adding noise to data)
- What a UNet is, and how to implement an extremely minimal one from scratch
- Diffusion model training
- Sampling theory
Then we'll compare our version with the diffusers DDPM implementation, exploring:
- Improvements over our mini UNet
- The DDPM noise schedule
- Differences in training objectives
- Timestep conditioning
- Sampling approaches
This notebook is fairly in-depth, and can safely be skipped if you're not excited about a from-scratch deep dive!
It is also worth noting that most of the code here is for illustrative purposes, and I wouldn't recommend directly adopting any of it for your own work (unless you're just trying to improve on the examples shown here for learning purposes).
Setup and imports:
>>> %pip install -q diffusers
>>> import torch
>>> import torchvision
>>> from torch import nn
>>> from torch.nn import functional as F
>>> from torch.utils.data import DataLoader
>>> from diffusers import DDPMScheduler, UNet2DModel
>>> from matplotlib import pyplot as plt
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> print(f"Using device: {device}")
Using device: cuda
The Data
Here we'll test things with a very small dataset: MNIST. If you'd like to give the model a slightly harder challenge without changing anything else, torchvision.datasets.FashionMNIST should work as a drop-in replacement.
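If you ever want to try that swap, a sketch of the alternative cell (not run here; the root folder name is just an illustrative choice) would be:
# Optional: FashionMNIST as a slightly harder drop-in replacement (illustrative sketch)
dataset = torchvision.datasets.FashionMNIST(
    root="fashion_mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor()
)
For this notebook we'll stick with plain MNIST: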
>>> dataset = torchvision.datasets.MNIST(
... root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor()
... )
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist/MNIST/raw/train-images-idx3-ubyte.gz
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
>>> x, y = next(iter(train_dataloader))
>>> print("Input shape:", x.shape)
>>> print("Labels:", y)
>>> plt.imshow(torchvision.utils.make_grid(x)[0], cmap="Greys")
Input shape: torch.Size([8, 1, 28, 28])
Labels: tensor([1, 9, 7, 3, 5, 2, 1, 4])
Each image is a greyscale 28px x 28px drawing of a digit, with values ranging from 0 to 1.
The Corruption Process
Pretend you haven't read any diffusion model papers, but you know the process involves adding noise. How would you do it?
We probably want an easy way to control the amount of corruption. So what if we take in a parameter for the amount of noise to add, and then we do:
noise = torch.rand_like(x)
noisy_x = (1-amount)*x + amount*noise
If amount = 0, we get back the input without any changes. If amount gets up to 1, we get back noise with no trace of the input x. By mixing the input with noise in this way, we keep the output in the same range (0 to 1).
We can implement this fairly easily (just watch the shapes so you don't get burnt by broadcasting rules):
def corrupt(x, amount):
"""Corrupt the input `x` by mixing it with noise according to `amount`"""
noise = torch.rand_like(x)
amount = amount.view(-1, 1, 1, 1) # Sort shape so broadcasting works
return x * (1 - amount) + noise * amount
And visualize the results to see that it works as expected:
>>> # Plotting the input data
>>> fig, axs = plt.subplots(2, 1, figsize=(12, 5))
>>> axs[0].set_title("Input data")
>>> axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap="Greys")
>>> # Adding noise
>>> amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
>>> noised_x = corrupt(x, amount)
>>> # Plotting the noised version
>>> axs[1].set_title("Corrupted data (-- amount increases -->)")
>>> axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap="Greys")
As the noise amount approaches one, our data starts to look like pure random noise. But for most noise amounts, you can guess the digit fairly well. Do you think this is optimal?
The Model
We'd like a model that takes in a 28px noisy image and outputs a prediction of the same shape. A popular choice here is an architecture called a UNet. Originally invented for segmentation tasks in medical imaging, a UNet consists of a 'constricting path' through which the data is compressed down and an 'expanding path' through which it expands back up to the original dimension (similar to an autoencoder), but it also features skip connections that allow information and gradients to flow across at different levels.
Some UNets feature complex blocks at each stage, but for this toy demo we'll build a minimal example that takes a one-channel image and passes it through three convolutional layers on the down path (the down_layers in the diagram and code) and three on the up path, with skip connections between the down and up layers. We'll use max pooling for downsampling and nn.Upsample for upsampling rather than relying on learnable layers like more sophisticated UNets. Here is the rough architecture showing the number of channels in the output of each layer:
This is what that looks like in code:
class BasicUNet(nn.Module):
"""A minimal UNet implementation."""
def __init__(self, in_channels=1, out_channels=1):
super().__init__()
self.down_layers = torch.nn.ModuleList(
[
nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
nn.Conv2d(32, 64, kernel_size=5, padding=2),
nn.Conv2d(64, 64, kernel_size=5, padding=2),
]
)
self.up_layers = torch.nn.ModuleList(
[
nn.Conv2d(64, 64, kernel_size=5, padding=2),
nn.Conv2d(64, 32, kernel_size=5, padding=2),
nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
]
)
self.act = nn.SiLU() # The activation function
self.downscale = nn.MaxPool2d(2)
self.upscale = nn.Upsample(scale_factor=2)
def forward(self, x):
h = []
for i, l in enumerate(self.down_layers):
x = self.act(l(x)) # Through the layer and the activation function
if i < 2: # For all but the third (final) down layer:
h.append(x) # Storing output for skip connection
x = self.downscale(x) # Downscale ready for the next layer
for i, l in enumerate(self.up_layers):
if i > 0: # For all except the first up layer
x = self.upscale(x) # Upscale
x += h.pop() # Fetching stored output (skip connection)
x = self.act(l(x)) # Through the layer and the activation function
return x
We can verify that the output shape is the same as the input, as we expect:
net = BasicUNet()
x = torch.rand(8, 1, 28, 28)
net(x).shape
This network has just over 300,000 parameters:
sum([p.numel() for p in net.parameters()])
You can explore changing the number of channels in each layer or swapping in different architectures if you want.
Training the Network
So what should the model do, exactly? Again, there are various takes on this, but for this demo let's pick a simple framework: given a corrupted input noisy_x, the model should output its best guess for what the original x looks like. We will compare this to the actual value via the mean squared error.
We can now have a go at training the network:
- Get a batch of data
- Corrupt it by random amounts
- Feed it through the model
- Compare the model predictions with the clean images to calculate our loss
- Update the model's parameters accordingly.
Feel free to modify this and see if you can get it working better!
>>> # Dataloader (you can mess with batch size)
>>> batch_size = 128
>>> train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
>>> # How many runs through the data should we do?
>>> n_epochs = 3
>>> # Create the network
>>> net = BasicUNet()
>>> net.to(device)
>>> # Our loss function
>>> loss_fn = nn.MSELoss()
>>> # The optimizer
>>> opt = torch.optim.Adam(net.parameters(), lr=1e-3)
>>> # Keeping a record of the losses for later viewing
>>> losses = []
>>> # The training loop
>>> for epoch in range(n_epochs):
... for x, y in train_dataloader:
... # Get some data and prepare the corrupted version
... x = x.to(device) # Data on the GPU
... noise_amount = torch.rand(x.shape[0]).to(device) # Pick random noise amounts
... noisy_x = corrupt(x, noise_amount) # Create our noisy x
... # Get the model prediction
... pred = net(noisy_x)
... # Calculate the loss
... loss = loss_fn(pred, x) # How close is the output to the true 'clean' x?
... # Backprop and update the params:
... opt.zero_grad()
... loss.backward()
... opt.step()
... # Store the loss for later
... losses.append(loss.item())
...     # Print out the average of the loss values for this epoch:
... avg_loss = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
... print(f"Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}")
>>> # View the loss curve
>>> plt.plot(losses)
>>> plt.ylim(0, 0.1)
Finished epoch 0. Average loss for this epoch: 0.026736
Finished epoch 1. Average loss for this epoch: 0.020692
Finished epoch 2. Average loss for this epoch: 0.018887
We can take a look at the model's predictions by grabbing a batch of data, corrupting it by different amounts and then viewing the model outputs:
>>> # @markdown Visualizing model predictions on noisy inputs:
>>> # Fetch some data
>>> x, y = next(iter(train_dataloader))
>>> x = x[:8] # Only using the first 8 for easy plotting
>>> # Corrupt with a range of amounts
>>> amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
>>> noised_x = corrupt(x, amount)
>>> # Get the model predictions
>>> with torch.no_grad():
... preds = net(noised_x.to(device)).detach().cpu()
>>> # Plot
>>> fig, axs = plt.subplots(3, 1, figsize=(12, 7))
>>> axs[0].set_title("Input data")
>>> axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap="Greys")
>>> axs[1].set_title("Corrupted data")
>>> axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap="Greys")
>>> axs[2].set_title("Network Predictions")
>>> axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap="Greys")
You can see that for the lower amounts the predictions are pretty good! But as the level gets very high there is less for the model to work with, and by the time we get to amount=1 it outputs a blurry mess close to the mean of the dataset as it tries to hedge its bets on what the output might look like...
Sampling
If our predictions at high noise levels aren't very good, how do we generate images?
Well, what if we start from random noise, look at the model predictions, but then only move a small amount towards that prediction - say, 20% of the way there? Now we have a very noisy image in which there is perhaps a hint of structure, which we can feed into the model to get a new prediction. The hope is that this new prediction is slightly better than the first one (since our starting point is slightly less noisy), and so we can take another small step with this new, better prediction.
Repeat a few times and (if all goes well) we get an image out! Here is that process illustrated over just 5 steps, visualizing the model input (left) and the predicted denoised images (right) at each stage. Note that even though the model predicts the denoised image as early as step 1, we only move x part of the way there. Over a few steps the structure appears and is refined, until we get our final output.
>>> # @markdown Sampling strategy: Break the process into 5 steps and move 1/5'th of the way there each time:
>>> n_steps = 5
>>> x = torch.rand(8, 1, 28, 28).to(device) # Start from random
>>> step_history = [x.detach().cpu()]
>>> pred_output_history = []
>>> for i in range(n_steps):
... with torch.no_grad(): # No need to track gradients during inference
... pred = net(x) # Predict the denoised x0
... pred_output_history.append(pred.detach().cpu()) # Store model output for plotting
... mix_factor = 1 / (n_steps - i) # How much we move towards the prediction
... x = x * (1 - mix_factor) + pred * mix_factor # Move part of the way there
... step_history.append(x.detach().cpu()) # Store step for plotting
>>> fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)
>>> axs[0, 0].set_title("x (model input)")
>>> axs[0, 1].set_title("model prediction")
>>> for i in range(n_steps):
... axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap="Greys")
... axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0, 1), cmap="Greys")
We can split the process up into more steps and hope for better images that way:
>>> # @markdown Showing more results, using 40 sampling steps
>>> n_steps = 40
>>> x = torch.rand(64, 1, 28, 28).to(device)
>>> for i in range(n_steps):
... noise_amount = torch.ones((x.shape[0],)).to(device) * (1 - (i / n_steps)) # Starting high going low
... with torch.no_grad():
... pred = net(x)
... mix_factor = 1 / (n_steps - i)
... x = x * (1 - mix_factor) + pred * mix_factor
>>> fig, ax = plt.subplots(1, 1, figsize=(12, 12))
>>> ax.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap="Greys")
Not great, but there are some recognizable digits there! You can experiment with training for longer (say, 10 or 20 epochs) and tweaking the model config, learning rate, optimizer and so on. Also, don't forget that fashionMNIST is a one-line replacement if you want a slightly harder dataset to try.
Comparison To DDPM
In this section we'll take a look at how our toy implementation differs from the approach used in the other notebook (Introduction to Diffusers), which is based on the DDPM paper.
We'll see that:
- The diffusers UNet2DModel is a bit more advanced than our BasicUNet
- The corruption process is handled differently
- The training objective is different, involving predicting the noise rather than the denoised image
- The model is conditioned on the amount of noise present via timestep conditioning, where t is passed as an additional argument to the forward method.
- There are a number of different sampling strategies available, which should work better than our simplistic version above.
There have been a number of improvements suggested since the DDPM paper came out, but this example will hopefully be instructive as to the different design decisions available. Once you've read through this, you may enjoy diving into the paper 'Elucidating the Design Space of Diffusion-Based Generative Models', which explores all of these components in some detail and makes new recommendations for how to get the best performance.
If all of this is too technical or intimidating, don't worry! Feel free to skip the rest of this notebook or save it for a rainy day.
UNet
The UNet2DModel in diffusers has a number of improvements over our BasicUNet above:
- GroupNorm applies group normalization to the inputs of each block
- Dropout layers for smoother training
- Multiple resnet layers per block (if layers_per_block isn't set to 1)
- Attention (usually used only at lower-resolution blocks)
- Conditioning on the timestep.
- Downsampling and upsampling blocks with learnable parameters
Let's create and inspect a UNet2DModel:
>>> model = UNet2DModel(
... sample_size=28, # the target image resolution
... in_channels=1, # the number of input channels, 3 for RGB images
... out_channels=1, # the number of output channels
... layers_per_block=2, # how many ResNet layers to use per UNet block
... block_out_channels=(32, 64, 64), # Roughly matching our basic unet example
... down_block_types=(
... "DownBlock2D", # a regular ResNet downsampling block
... "AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
... "AttnDownBlock2D",
... ),
... up_block_types=(
... "AttnUpBlock2D",
... "AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
... "UpBlock2D", # a regular ResNet upsampling block
... ),
... )
>>> print(model)
UNet2DModel( (conv_in): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_proj): Timesteps() (time_embedding): TimestepEmbedding( (linear_1): Linear(in_features=32, out_features=128, bias=True) (act): SiLU() (linear_2): Linear(in_features=128, out_features=128, bias=True) ) (down_blocks): ModuleList( (0): DownBlock2D( (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 32, eps=1e-05, affine=True) (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=32, bias=True) (norm2): GroupNorm(32, 32, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 32, eps=1e-05, affine=True) (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=32, bias=True) (norm2): GroupNorm(32, 32, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) ) (downsamplers): ModuleList( (0): Downsample2D( (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) ) ) ) (1): AttnDownBlock2D( (attentions): ModuleList( (0): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) (1): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) ) (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 32, eps=1e-05, affine=True) (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1)) ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) ) (downsamplers): ModuleList( (0): Downsample2D( (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) ) ) ) (2): AttnDownBlock2D( (attentions): ModuleList( (0): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) (1): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): 
Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) ) (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) ) ) ) (up_blocks): ModuleList( (0): AttnUpBlock2D( (attentions): ModuleList( (0): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) (1): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) (2): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) ) (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-05, affine=True) (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1)) ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-05, affine=True) (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1)) ) (2): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-05, affine=True) (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1)) ) ) (upsamplers): ModuleList( (0): 
Upsample2D( (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) ) (1): AttnUpBlock2D( (attentions): ModuleList( (0): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) (1): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) (2): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) ) (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-05, affine=True) (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1)) ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-05, affine=True) (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1)) ) (2): ResnetBlock2D( (norm1): GroupNorm(32, 96, eps=1e-05, affine=True) (conv1): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1)) ) ) (upsamplers): ModuleList( (0): Upsample2D( (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) ) (2): UpBlock2D( (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 96, eps=1e-05, affine=True) (conv1): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=32, bias=True) (norm2): GroupNorm(32, 32, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1)) ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=32, bias=True) (norm2): GroupNorm(32, 32, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 
1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1)) ) (2): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=32, bias=True) (norm2): GroupNorm(32, 32, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1)) ) ) ) ) (mid_block): UNetMidBlock2D( (attentions): ModuleList( (0): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) ) (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) ) ) (conv_norm_out): GroupNorm(32, 32, eps=1e-05, affine=True) (conv_act): SiLU() (conv_out): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) )
As you can see, a little more going on! It also has significantly more parameters than our BasicUNet:
sum([p.numel() for p in model.parameters()]) # 1.7M vs the ~309k parameters of the BasicUNet
We can replicate the training shown above using this model in place of our original one. We need to pass both x and the timestep to the model (here I always pass t=0 to show that it works without this timestep conditioning and to keep the sampling code easy, but you can also try feeding in (amount*1000) to get a timestep equivalent from the corruption amount). The changed lines are marked with #<<< if you want to inspect the code.
>>> # @markdown Trying UNet2DModel instead of BasicUNet:
>>> # Dataloader (you can mess with batch size)
>>> batch_size = 128
>>> train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
>>> # How many runs through the data should we do?
>>> n_epochs = 3
>>> # Create the network
>>> net = UNet2DModel(
... sample_size=28, # the target image resolution
... in_channels=1, # the number of input channels, 3 for RGB images
... out_channels=1, # the number of output channels
... layers_per_block=2, # how many ResNet layers to use per UNet block
... block_out_channels=(32, 64, 64), # Roughly matching our basic unet example
... down_block_types=(
... "DownBlock2D", # a regular ResNet downsampling block
... "AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
... "AttnDownBlock2D",
... ),
... up_block_types=(
... "AttnUpBlock2D",
... "AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
... "UpBlock2D", # a regular ResNet upsampling block
... ),
... ) # <<<
>>> net.to(device)
>>> # Our loss function
>>> loss_fn = nn.MSELoss()
>>> # The optimizer
>>> opt = torch.optim.Adam(net.parameters(), lr=1e-3)
>>> # Keeping a record of the losses for later viewing
>>> losses = []
>>> # The training loop
>>> for epoch in range(n_epochs):
... for x, y in train_dataloader:
... # Get some data and prepare the corrupted version
... x = x.to(device) # Data on the GPU
... noise_amount = torch.rand(x.shape[0]).to(device) # Pick random noise amounts
... noisy_x = corrupt(x, noise_amount) # Create our noisy x
... # Get the model prediction
... pred = net(noisy_x, 0).sample # <<< Using timestep 0 always, adding .sample
... # Calculate the loss
... loss = loss_fn(pred, x) # How close is the output to the true 'clean' x?
... # Backprop and update the params:
... opt.zero_grad()
... loss.backward()
... opt.step()
... # Store the loss for later
... losses.append(loss.item())
...     # Print out the average of the loss values for this epoch:
... avg_loss = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
... print(f"Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}")
>>> # Plot losses and some samples
>>> fig, axs = plt.subplots(1, 2, figsize=(12, 5))
>>> # Losses
>>> axs[0].plot(losses)
>>> axs[0].set_ylim(0, 0.1)
>>> axs[0].set_title("Loss over time")
>>> # Samples
>>> n_steps = 40
>>> x = torch.rand(64, 1, 28, 28).to(device)
>>> for i in range(n_steps):
... noise_amount = torch.ones((x.shape[0],)).to(device) * (1 - (i / n_steps)) # Starting high going low
... with torch.no_grad():
... pred = net(x, 0).sample
... mix_factor = 1 / (n_steps - i)
... x = x * (1 - mix_factor) + pred * mix_factor
>>> axs[1].imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap="Greys")
>>> axs[1].set_title("Generated Samples")
Finished epoch 0. Average loss for this epoch: 0.018925
Finished epoch 1. Average loss for this epoch: 0.012785
Finished epoch 2. Average loss for this epoch: 0.011694
This looks much better than our first set of results! You can explore tweaking the unet configuration or training for longer to get even better performance.
The Corruption Process
The DDPM paper describes a corruption process that adds a small amount of noise for every 'timestep'. Given $x_{t-1}$ for some timestep, we can get the next (slightly noisier) version $x_t$ with:
$q(\mathbf{x}_t \vert \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}_{t-1}, \beta_t\mathbf{I}) \quad q(\mathbf{x}_{1:T} \vert \mathbf{x}_0) = \prod^T_{t=1} q(\mathbf{x}_t \vert \mathbf{x}_{t-1})$
That is, we take $x_{t-1}$, scale it by $\sqrt{1 - \beta_t}$ and add noise scaled by $\beta_t$. This $\beta$ is defined for every t according to some schedule, and determines how much noise is added per timestep. Now, we don't necessarily want to do this operation 500 times to get $x_{500}$, so we have another formula to get $x_t$ for any t given $x_0$:
$q(\mathbf{x}_t \vert \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, \sqrt{(1 - \bar{\alpha}_t)} \mathbf{I})$ where $\bar{\alpha}_t = \prod_{i=1}^{T} \alpha_i$ and $\alpha_i = 1-\beta_i$
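Written as a reparameterization (the same Gaussian, just sampled explicitly), this says a noisy version can be drawn directly as $\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\,\boldsymbol{\epsilon}$ with $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, which is essentially the operation the scheduler's add_noise method performs.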
The maths notation always looks scary! Luckily the scheduler handles all of that for us (uncomment the next cell to check out the code). We can plot $\sqrt{\bar{\alpha}_t}$ (labelled sqrt_alpha_prod) and $\sqrt{(1 - \bar{\alpha}_t)}$ (labelled sqrt_one_minus_alpha_prod) to view how the input (x) and the noise are scaled and mixed across different timesteps:
# ??noise_scheduler.add_noise
>>> noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
>>> plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${\sqrt{\bar{\alpha}_t}}$")
>>> plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5, label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")
>>> plt.legend(fontsize="x-large")
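If you'd rather see that mixing spelled out in code than dig into the scheduler source, here is a minimal hand-rolled sketch (illustration only; the tensors x0 and t below are made-up examples) of the same operation that add_noise performs:
# Illustration only: manually reproducing the scheduler's mixing for some example timesteps
x0 = torch.randn(8, 1, 28, 28)                             # stand-in for a batch of images scaled to (-1, 1)
t = torch.tensor([1, 100, 300, 500, 700, 850, 950, 999])   # one example timestep per image
alpha_bar = noise_scheduler.alphas_cumprod[t].view(-1, 1, 1, 1)
eps = torch.randn_like(x0)                                 # unit-variance Gaussian noise
xt = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * eps  # equivalent to noise_scheduler.add_noise(x0, eps, t)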
Initially, the noisy x is mostly x (sqrt_alpha_prod ~= 1), but over time the contribution of x drops and the noise component increases. Unlike our linear mix of x and noise according to amount, this one gets noisy relatively quickly. We can visualize this on some data:
>>> # @markdown visualize the DDPM noising process for different timesteps:
>>> # Noise a batch of images to view the effect
>>> fig, axs = plt.subplots(3, 1, figsize=(16, 10))
>>> xb, yb = next(iter(train_dataloader))
>>> xb = xb.to(device)[:8]
>>> xb = xb * 2.0 - 1.0 # Map to (-1, 1)
>>> print("X shape", xb.shape)
>>> # Show clean inputs
>>> axs[0].imshow(torchvision.utils.make_grid(xb[:8])[0].detach().cpu(), cmap="Greys")
>>> axs[0].set_title("Clean X")
>>> # Add noise with scheduler
>>> timesteps = torch.linspace(0, 999, 8).long().to(device)
>>> noise = torch.randn_like(xb) # << NB: randn not rand
>>> noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
>>> print("Noisy X shape", noisy_xb.shape)
>>> # Show noisy version (with and without clipping)
>>> axs[1].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu().clip(-1, 1), cmap="Greys")
>>> axs[1].set_title("Noisy X (clipped to (-1, 1)")
>>> axs[2].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu(), cmap="Greys")
>>> axs[2].set_title("Noisy X")
X shape torch.Size([8, 1, 28, 28])
Noisy X shape torch.Size([8, 1, 28, 28])
One more dynamic at play: the DDPM version adds noise drawn from a Gaussian distribution (mean 0, standard deviation 1, from torch.randn) rather than the uniform noise between 0 and 1 (from torch.rand) that we used in our original corrupt function. In general it also makes sense to normalize the training data. In the other notebook you'll see Normalize(0.5, 0.5) in the list of transforms, which maps the image data from (0, 1) to (-1, 1) and is 'good enough' for our purposes. We didn't do that for this notebook, but the visualization cell above adds it for more accurate scaling and visualization.
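For reference, a sketch of the kind of transform pipeline that includes this normalization (this is roughly what the other notebook does; it isn't used in this one) might look like:
# Sketch: mapping image data from (0, 1) to (-1, 1) with torchvision transforms
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),           # PIL image -> tensor in (0, 1)
        torchvision.transforms.Normalize(0.5, 0.5),  # (x - 0.5) / 0.5 -> (-1, 1)
    ]
)
# dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=transform)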
The Training Objective
In our toy example we had the model try to predict the denoised image. In DDPM and many other diffusion model implementations, the model instead predicts the noise used in the corruption process (before scaling, so unit-variance noise). In code, that looks something like:
noise = torch.randn_like(xb)  # << NB: randn not rand
noisy_x = noise_scheduler.add_noise(xb, noise, timesteps)
model_prediction = model(noisy_x, timesteps).sample
loss = F.mse_loss(model_prediction, noise)  # noise as the target
You may think that predicting the noise (from which we can derive what the denoised image looks like) is equivalent to predicting the denoised image directly. So why favour one over the other - is it just for mathematical convenience?
It turns out there's another subtlety here. We compute the loss across different (randomly chosen) timesteps during training. These different objectives lead to different 'implicit weightings' of those losses, with the noise-prediction objective putting more weight on the lower noise levels. You can pick more complex objectives to change this 'implicit loss weighting'. Or perhaps you choose a noise schedule that results in more examples at higher noise levels. Perhaps you have the model predict a 'velocity' v, defined as a combination of the image and the noise depending on the noise level (see 'PROGRESSIVE DISTILLATION FOR FAST SAMPLING OF DIFFUSION MODELS'). Perhaps you have the model predict the noise but then scale the loss by some factor that depends on the amount of noise, based either on a bit of theory (see 'Perception Prioritized Training of Diffusion Models') or on experiments that try to identify which noise levels are most informative to the model (see 'Elucidating the Design Space of Diffusion-Based Generative Models'). TL;DR: the choice of objective has an effect on model performance, and research into what the 'best' choice is remains ongoing.
At the moment, predicting the noise (epsilon or eps, as you may see it written in some places) is the favoured approach, but over time we will likely see other objectives supported in the library and used in different situations.
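To make the relationship between these targets concrete, here is a small sketch (illustration only, reusing xb and noise_scheduler from the cells above; the v definition follows the progressive distillation paper) showing how each candidate target can be derived from the same clean image and noise:
# Illustration only: the different quantities a model could be trained to predict
x0 = xb.cpu()                                   # clean images, already scaled to (-1, 1) above
eps = torch.randn_like(x0)                      # unit-variance noise
t = torch.linspace(0, 999, x0.shape[0]).long()  # example timesteps
alpha_bar = noise_scheduler.alphas_cumprod[t].view(-1, 1, 1, 1)
target_eps = eps                                                 # DDPM: predict the noise
target_x0 = x0                                                   # our toy example: predict the clean image
target_v = alpha_bar.sqrt() * eps - (1 - alpha_bar).sqrt() * x0  # 'velocity' (v) prediction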
Timestep Conditioning
The UNet2DModel takes in both x and a timestep. The latter is turned into an embedding and fed into the model in a number of places.
The theory behind this is that by giving the model information about what the noise level is, it can perform its task better. While it is possible to train a model without this timestep conditioning, it does seem to help performance in some cases, and most implementations include it, at least in the current literature.
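In code, the conditioning simply means passing a timestep (or a batch of them) as the second argument of the forward call. A minimal sketch (reusing xb, noise and noise_scheduler from above, and assuming net is a UNet2DModel like the one trained earlier):
# Sketch: conditioning the UNet on a per-image timestep
timesteps = torch.randint(0, 1000, (xb.shape[0],)).to(device)  # a random timestep for each image
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)     # noise each image according to its timestep
with torch.no_grad():
    pred = net(noisy_xb, timesteps).sample                     # the timestep is embedded and fed into the model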
Sampling
Given a model that estimates the noise present in a noisy input (or predicts the denoised version), how do we produce new images?
We could feed in pure noise and just hope the model predicts a good image as the denoised version in one step. However, as we saw in the experiments above, this doesn't usually work well. So instead we take a number of smaller steps based on the model's prediction, iteratively removing a little of the noise at a time (a minimal sketch of such a loop with a diffusers scheduler follows the list below).
How exactly we take these steps depends on the sampling method used. We won't go into the theory too deeply, but some key design questions are:
- How big a step should you take? In other words, what 'noise schedule' should you follow?
- Do you use only the model's current prediction to inform the update step (like DDPM, DDIM and many others)? Do you evaluate the model several times to estimate higher-order gradients for a larger, more accurate step (higher-order methods and some discrete ODE solvers)? Or do you keep a history of past predictions to try and better inform the current update step (linear multi-step and ancestral samplers)?
- Do you add additional noise (sometimes called churn) to introduce more stochasticity (randomness) into the sampling process, or do you keep it completely deterministic? Many samplers control this with a parameter (such as 'eta' for DDIM samplers) so that the user can choose.
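Here is that loop sketched with the diffusers DDPMScheduler (a rough illustration, not something run in this notebook: it assumes a model trained with the DDPM noise-prediction objective, which neither of the networks trained above actually is):
# Sketch: a DDPM-style sampling loop with a diffusers scheduler
scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=50)       # how many denoising steps to take
sample = torch.randn(8, 1, 28, 28).to(device)         # start from pure Gaussian noise
for t in scheduler.timesteps:
    with torch.no_grad():
        noise_pred = model(sample, t).sample          # the model predicts the noise at this timestep
    sample = scheduler.step(noise_pred, t, sample).prev_sample  # remove a little of that noise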
Research on sampling methods for diffusion models is rapidly evolving, and more and more ways of finding good solutions in fewer steps are being proposed. The brave and curious might find it interesting to browse through the code of the different implementations available in the diffusers library, or to check out the docs, which often link to the relevant papers.
In Conclusion
Hopefully this has been a helpful way to look at diffusion models from a slightly different angle.
This notebook was written by Jonathan Whitaker for the Hugging Face course, and overlaps with a version included in his own course, 'The Generative Landscape'. Check that out if you'd like to see this basic example extended with noise and class conditioning. Questions and bugs can be communicated via GitHub issues or Discord. You're also welcome to reach out via Twitter @johnowhitaker.