Diffusion Models from Scratch

Sometimes it helps to study the simplest possible version of something in order to understand how it works. That's what we'll try in this notebook, beginning with a 'toy' diffusion model to see how the different pieces work, and then examining how they differ from a more complex implementation.

We will look at:

  • The corruption process (adding noise to data)
  • What a UNet is, and how to implement an extremely minimal one from scratch
  • Diffusion model training
  • Sampling theory

Then we'll compare our versions with the diffusers DDPM implementation, exploring:

  • Improvements over our mini UNet
  • The DDPM noise schedule
  • Differences in the training objective
  • Timestep conditioning
  • Sampling approaches

This notebook goes fairly deep, and you can safely skip it if a from-scratch deep dive doesn't excite you!

It's also worth noting that most of the code here is for illustrative purposes, and I wouldn't recommend adopting any of it directly in your own work (unless you're just trying to improve on the examples shown here for learning purposes).

Setup and imports:

>>> %pip install -q diffusers
>>> import torch
>>> import torchvision
>>> from torch import nn
>>> from torch.nn import functional as F
>>> from torch.utils.data import DataLoader
>>> from diffusers import DDPMScheduler, UNet2DModel
>>> from matplotlib import pyplot as plt

>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> print(f"Using device: {device}")
Using device: cuda

The data

Here we'll test things on a very small dataset: MNIST. If you'd like to give the model a slightly harder challenge without changing anything else, torchvision.datasets.FashionMNIST should work as a drop-in replacement.

>>> dataset = torchvision.datasets.MNIST(
...     root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor()
... )
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist/MNIST/raw/train-images-idx3-ubyte.gz
>>> train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
>>> x, y = next(iter(train_dataloader))
>>> print("Input shape:", x.shape)
>>> print("Labels:", y)
>>> plt.imshow(torchvision.utils.make_grid(x)[0], cmap="Greys")
Input shape: torch.Size([8, 1, 28, 28])
Labels: tensor([1, 9, 7, 3, 5, 2, 1, 4])

Each image is a greyscale 28x28 pixel drawing of a digit, with values ranging from 0 to 1.

The corruption process

Say you haven't read any diffusion model papers, but you know the process involves adding noise. How would you do it?

We probably want an easy way to control the amount of corruption. So what if we take in a parameter amount for how much noise to add, and then we do:

noise = torch.rand_like(x)

noisy_x = (1-amount)*x + amount*noise

If amount = 0, we get back the input with no changes. If amount gets up to 1, we get back noise with no trace of the input x. By mixing the input with noise this way, we keep the output in the same range (0 to 1).

We can implement this fairly easily (just watch the shapes so you don't get bitten by broadcasting rules):

def corrupt(x, amount):
    """Corrupt the input `x` by mixing it with noise according to `amount`"""
    noise = torch.rand_like(x)
    amount = amount.view(-1, 1, 1, 1)  # Sort shape so broadcasting works
    return x * (1 - amount) + noise * amount

And check visually that it works as expected:

>>> # Plotting the input data
>>> fig, axs = plt.subplots(2, 1, figsize=(12, 5))
>>> axs[0].set_title("Input data")
>>> axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap="Greys")

>>> # Adding noise
>>> amount = torch.linspace(0, 1, x.shape[0])  # Left to right -> more corruption
>>> noised_x = corrupt(x, amount)

>>> # Plotting the noised version
>>> axs[1].set_title("Corrupted data (-- amount increases -->)")
>>> axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap="Greys")

As the amount of noise approaches one, our data starts to look like pure random noise. But for most noise amounts, you can still guess the digit fairly well. Do you think this is optimal?

The model

We'd like a model that takes in a 28px noisy image and outputs a prediction of the same shape. A popular choice here is an architecture called a UNet. Originally invented for segmentation tasks in medical imaging, a UNet consists of a 'contracting path' through which the data is compressed down and an 'expanding path' through which it expands back up to the original dimension (similar to an autoencoder), but it also features skip connections that allow information and gradients to flow across the different levels.

Some UNets feature complex blocks at each stage, but for this toy demo we'll build a minimal example that takes a one-channel image through three convolutional layers on the downsampling path (the down_layers in the diagram and the code) and three on the upsampling path, with skip connections between the down and up layers. We'll use max pooling for downsampling and nn.Upsample for upsampling rather than relying on learnable layers like more complex UNets do. Here is the rough architecture, showing the number of channels in the output of each layer:

unet_diag.png

Here's what that looks like in code:

class BasicUNet(nn.Module):
    """A minimal UNet implementation."""

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down_layers = torch.nn.ModuleList(
            [
                nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
                nn.Conv2d(32, 64, kernel_size=5, padding=2),
                nn.Conv2d(64, 64, kernel_size=5, padding=2),
            ]
        )
        self.up_layers = torch.nn.ModuleList(
            [
                nn.Conv2d(64, 64, kernel_size=5, padding=2),
                nn.Conv2d(64, 32, kernel_size=5, padding=2),
                nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
            ]
        )
        self.act = nn.SiLU()  # The activation function
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        h = []
        for i, l in enumerate(self.down_layers):
            x = self.act(l(x))  # Through the layer and the activation function
            if i < 2:  # For all but the third (final) down layer:
                h.append(x)  # Storing output for skip connection
                x = self.downscale(x)  # Downscale ready for the next layer

        for i, l in enumerate(self.up_layers):
            if i > 0:  # For all except the first up layer
                x = self.upscale(x)  # Upscale
                x += h.pop()  # Fetching stored output (skip connection)
            x = self.act(l(x))  # Through the layer and the activation function

        return x

We can verify that the output shape is the same as the input, as we'd expect:

net = BasicUNet()
x = torch.rand(8, 1, 28, 28)
net(x).shape

This network has a little over 300,000 parameters:

sum([p.numel() for p in net.parameters()])

Feel free to change the number of channels in each layer, or to swap in a different architecture if you want.

Training the network

So what should the model do, exactly? Again, there are various takes on this, but for this demo let's pick a simple framing: given a corrupted input noisy_x, the model should output its best guess for what the original x looks like. We'll compare this against the actual value via the mean squared error.

We can now have a go at training the network:

  • Get a batch of data
  • Corrupt it by random amounts
  • Feed it through the model
  • Compare the model predictions with the clean images to calculate our loss
  • Update the model's parameters accordingly.

Feel free to tinker with this and see if you can get it working better!

>>> # Dataloader (you can mess with batch size)
>>> batch_size = 128
>>> train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

>>> # How many runs through the data should we do?
>>> n_epochs = 3

>>> # Create the network
>>> net = BasicUNet()
>>> net.to(device)

>>> # Our loss function
>>> loss_fn = nn.MSELoss()

>>> # The optimizer
>>> opt = torch.optim.Adam(net.parameters(), lr=1e-3)

>>> # Keeping a record of the losses for later viewing
>>> losses = []

>>> # The training loop
>>> for epoch in range(n_epochs):

...     for x, y in train_dataloader:

...         # Get some data and prepare the corrupted version
...         x = x.to(device)  # Data on the GPU
...         noise_amount = torch.rand(x.shape[0]).to(device)  # Pick random noise amounts
...         noisy_x = corrupt(x, noise_amount)  # Create our noisy x

...         # Get the model prediction
...         pred = net(noisy_x)

...         # Calculate the loss
...         loss = loss_fn(pred, x)  # How close is the output to the true 'clean' x?

...         # Backprop and update the params:
...         opt.zero_grad()
...         loss.backward()
...         opt.step()

...         # Store the loss for later
...         losses.append(loss.item())

...     # Print out the average of the loss values for this epoch:
...     avg_loss = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
...     print(f"Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}")

>>> # View the loss curve
>>> plt.plot(losses)
>>> plt.ylim(0, 0.1)
Finished epoch 0. Average loss for this epoch: 0.026736
Finished epoch 1. Average loss for this epoch: 0.020692
Finished epoch 2. Average loss for this epoch: 0.018887

We can see how the model performs by taking a batch of data, corrupting it by different amounts and then viewing the model's predictions:

>>> # @markdown Visualizing model predictions on noisy inputs:

>>> # Fetch some data
>>> x, y = next(iter(train_dataloader))
>>> x = x[:8]  # Only using the first 8 for easy plotting

>>> # Corrupt with a range of amounts
>>> amount = torch.linspace(0, 1, x.shape[0])  # Left to right -> more corruption
>>> noised_x = corrupt(x, amount)

>>> # Get the model predictions
>>> with torch.no_grad():
...     preds = net(noised_x.to(device)).detach().cpu()

>>> # Plot
>>> fig, axs = plt.subplots(3, 1, figsize=(12, 7))
>>> axs[0].set_title("Input data")
>>> axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap="Greys")
>>> axs[1].set_title("Corrupted data")
>>> axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap="Greys")
>>> axs[2].set_title("Network Predictions")
>>> axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap="Greys")

You can see that for lower noise amounts the predictions are pretty good! But as the noise level gets very high there is less for the model to work with, and by the time the amount reaches 1 it outputs a blurry image close to the mean of the dataset, its safest guess in the face of the uncertainty...

Sampling

If our predictions at high noise levels aren't very good, how do we generate images?

Well, what if we start from random noise, look at the model's prediction, but then only move a small amount towards that prediction, say 20% of the way there. Now we have a very noisy image in which there is perhaps a hint of structure, which we can feed into the model to get a new prediction. The hope is that this new prediction is slightly better than the first one (since our starting point is slightly less noisy), and so we can take another small step with this new, better prediction.

Repeat a few times and, if all goes well, we get an image out! Here is that process illustrated over just 5 steps, visualizing the model input (left) and the predicted denoised image (right) at each stage. Note that even though the model predicts the denoised image as early as step 1, we only move x part of the way there. Over a few steps the structure appears and is refined, until we get our final outputs.

>>> # @markdown Sampling strategy: Break the process into 5 steps and move 1/5'th of the way there each time:
>>> n_steps = 5
>>> x = torch.rand(8, 1, 28, 28).to(device)  # Start from random
>>> step_history = [x.detach().cpu()]
>>> pred_output_history = []

>>> for i in range(n_steps):
...     with torch.no_grad():  # No need to track gradients during inference
...         pred = net(x)  # Predict the denoised x0
...     pred_output_history.append(pred.detach().cpu())  # Store model output for plotting
...     mix_factor = 1 / (n_steps - i)  # How much we move towards the prediction
...     x = x * (1 - mix_factor) + pred * mix_factor  # Move part of the way there
...     step_history.append(x.detach().cpu())  # Store step for plotting

>>> fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)
>>> axs[0, 0].set_title("x (model input)")
>>> axs[0, 1].set_title("model prediction")
>>> for i in range(n_steps):
...     axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap="Greys")
...     axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0, 1), cmap="Greys")

We can split the process up into more steps in the hope of getting better images:

>>> # @markdown Showing more results, using 40 sampling steps
>>> n_steps = 40
>>> x = torch.rand(64, 1, 28, 28).to(device)
>>> for i in range(n_steps):
...     noise_amount = torch.ones((x.shape[0],)).to(device) * (1 - (i / n_steps))  # Starting high going low
...     with torch.no_grad():
...         pred = net(x)
...     mix_factor = 1 / (n_steps - i)
...     x = x * (1 - mix_factor) + pred * mix_factor
>>> fig, ax = plt.subplots(1, 1, figsize=(12, 12))
>>> ax.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap="Greys")

Not great, but there are some recognizable digits there! You can try training for longer (10 or 20 epochs, say) and tweaking the model config, learning rate, optimizer and so on. Also, don't forget that FashionMNIST is a one-line replacement if you want a slightly harder dataset to try.
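For example, that one-line swap (a sketch that keeps every other setting the same) would look like this:

dataset = torchvision.datasets.FashionMNIST(
    root="fashion_mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor()
)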

The comparison with DDPM

In this section we'll look at how our toy implementation differs from the approach used in the other notebook (Introduction to Diffusers), which is based on the DDPM paper.

We'll see:

  • The diffusers UNet2DModel is a bit more advanced than our BasicUNet
  • The corruption process is handled differently
  • The training objective is different, involving predicting the noise rather than the denoised image
  • The model is conditioned on the amount of noise present via timestep conditioning, where t is passed as an additional argument to the forward method.
  • There are a number of different sampling strategies available, which should work better than our simplistic version above.

A number of improvements have been suggested since the DDPM paper came out, but this example is hopefully instructive as to the different design decisions available. Once you've read through it, you may enjoy diving into the paper ‘Elucidating the Design Space of Diffusion-Based Generative Models’, which explores all of these components in detail and makes new recommendations for how to get the best performance.

If all of this is too technical or intimidating, don't worry! Feel free to skip the rest of this notebook or save it for a rainy day.

UNet

The UNet2DModel in diffusers has a number of improvements over our basic UNet above:

  • GroupNorm applies group normalization to the inputs of each block
  • Dropout layers for smoother training
  • Multiple resnet layers per block (if layers_per_block isn't set to 1)
  • Attention (usually used only at lower-resolution blocks)
  • Conditioning on the timestep.
  • Downsampling and upsampling blocks with learnable parameters

Let's create and inspect a UNet2DModel:

>>> model = UNet2DModel(
...     sample_size=28,  # the target image resolution
...     in_channels=1,  # the number of input channels, 3 for RGB images
...     out_channels=1,  # the number of output channels
...     layers_per_block=2,  # how many ResNet layers to use per UNet block
...     block_out_channels=(32, 64, 64),  # Roughly matching our basic unet example
...     down_block_types=(
...         "DownBlock2D",  # a regular ResNet downsampling block
...         "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
...         "AttnDownBlock2D",
...     ),
...     up_block_types=(
...         "AttnUpBlock2D",
...         "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
...         "UpBlock2D",  # a regular ResNet upsampling block
...     ),
... )
>>> print(model)
UNet2DModel(
  (conv_in): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=32, out_features=128, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=128, out_features=128, bias=True)
  )
  (down_blocks): ModuleList(
    (0): DownBlock2D(
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
          (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
          (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
        (1): ResnetBlock2D(
          (norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
          (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
          (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
      (downsamplers): ModuleList(
        (0): Downsample2D(
          (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
      )
    )
    (1): AttnDownBlock2D(
      (attentions): ModuleList(
        (0): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
        (1): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
      )
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
          (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): ResnetBlock2D(
          (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
      (downsamplers): ModuleList(
        (0): Downsample2D(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
      )
    )
    (2): AttnDownBlock2D(
      (attentions): ModuleList(
        (0): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
        (1): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
      )
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
        (1): ResnetBlock2D(
          (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
    )
  )
  (up_blocks): ModuleList(
    (0): AttnUpBlock2D(
      (attentions): ModuleList(
        (0): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
        (1): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
        (2): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
      )
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
          (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): ResnetBlock2D(
          (norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
          (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
        )
        (2): ResnetBlock2D(
          (norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
          (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (upsamplers): ModuleList(
        (0): Upsample2D(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
    )
    (1): AttnUpBlock2D(
      (attentions): ModuleList(
        (0): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
        (1): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
        (2): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
      )
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
          (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): ResnetBlock2D(
          (norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
          (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
        )
        (2): ResnetBlock2D(
          (norm1): GroupNorm(32, 96, eps=1e-05, affine=True)
          (conv1): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (upsamplers): ModuleList(
        (0): Upsample2D(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
    )
    (2): UpBlock2D(
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 96, eps=1e-05, affine=True)
          (conv1): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
          (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): ResnetBlock2D(
          (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
          (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
        )
        (2): ResnetBlock2D(
          (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
          (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
  )
  (mid_block): UNetMidBlock2D(
    (attentions): ModuleList(
      (0): AttentionBlock(
        (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (key): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (proj_attn): Linear(in_features=64, out_features=64, bias=True)
      )
    )
    (resnets): ModuleList(
      (0): ResnetBlock2D(
        (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
        (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (nonlinearity): SiLU()
      )
      (1): ResnetBlock2D(
        (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
        (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (nonlinearity): SiLU()
      )
    )
  )
  (conv_norm_out): GroupNorm(32, 32, eps=1e-05, affine=True)
  (conv_act): SiLU()
  (conv_out): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

As you can see, a little more going on here! It also has significantly more parameters than our BasicUNet:

sum([p.numel() for p in model.parameters()])  # 1.7M vs the ~309k parameters of the BasicUNet

We can replicate the training shown above using this model in place of our original one. We need to pass both x and the timestep to the model (here I always pass t=0 to show that it works without this timestep conditioning and to keep the sampling code simple, but you can also try feeding in (amount*1000) to get a timestep equivalent to the corruption amount). The changed lines are marked with # <<< if you want to inspect the code.
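As an aside, if you want to see what the (amount*1000) idea looks like in practice, here is a minimal sketch (my own, using the untrained `model` instance created above purely to show the call signature; the training loop below sticks with t=0):

amount = torch.rand(8)                             # corruption amounts in [0, 1]
timesteps = (amount * 999).long()                  # rough mapping onto the usual 0-999 timestep range
noisy = corrupt(torch.rand(8, 1, 28, 28), amount)  # corrupt some dummy 'images'
with torch.no_grad():
    pred = model(noisy, timesteps).sample          # UNet2DModel accepts a per-sample timestep tensor
print(pred.shape)                                  # torch.Size([8, 1, 28, 28])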

>>> # @markdown Trying UNet2DModel instead of BasicUNet:

>>> # Dataloader (you can mess with batch size)
>>> batch_size = 128
>>> train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

>>> # How many runs through the data should we do?
>>> n_epochs = 3

>>> # Create the network
>>> net = UNet2DModel(
...     sample_size=28,  # the target image resolution
...     in_channels=1,  # the number of input channels, 3 for RGB images
...     out_channels=1,  # the number of output channels
...     layers_per_block=2,  # how many ResNet layers to use per UNet block
...     block_out_channels=(32, 64, 64),  # Roughly matching our basic unet example
...     down_block_types=(
...         "DownBlock2D",  # a regular ResNet downsampling block
...         "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
...         "AttnDownBlock2D",
...     ),
...     up_block_types=(
...         "AttnUpBlock2D",
...         "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
...         "UpBlock2D",  # a regular ResNet upsampling block
...     ),
... )  # <<<
>>> net.to(device)

>>> # Our loss function
>>> loss_fn = nn.MSELoss()

>>> # The optimizer
>>> opt = torch.optim.Adam(net.parameters(), lr=1e-3)

>>> # Keeping a record of the losses for later viewing
>>> losses = []

>>> # The training loop
>>> for epoch in range(n_epochs):

...     for x, y in train_dataloader:

...         # Get some data and prepare the corrupted version
...         x = x.to(device)  # Data on the GPU
...         noise_amount = torch.rand(x.shape[0]).to(device)  # Pick random noise amounts
...         noisy_x = corrupt(x, noise_amount)  # Create our noisy x

...         # Get the model prediction
...         pred = net(noisy_x, 0).sample  # <<< Using timestep 0 always, adding .sample

...         # Calculate the loss
...         loss = loss_fn(pred, x)  # How close is the output to the true 'clean' x?

...         # Backprop and update the params:
...         opt.zero_grad()
...         loss.backward()
...         opt.step()

...         # Store the loss for later
...         losses.append(loss.item())

...     # Print out the average of the loss values for this epoch:
...     avg_loss = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
...     print(f"Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}")

>>> # Plot losses and some samples
>>> fig, axs = plt.subplots(1, 2, figsize=(12, 5))

>>> # Losses
>>> axs[0].plot(losses)
>>> axs[0].set_ylim(0, 0.1)
>>> axs[0].set_title("Loss over time")

>>> # Samples
>>> n_steps = 40
>>> x = torch.rand(64, 1, 28, 28).to(device)
>>> for i in range(n_steps):
...     noise_amount = torch.ones((x.shape[0],)).to(device) * (1 - (i / n_steps))  # Starting high going low
...     with torch.no_grad():
...         pred = net(x, 0).sample
...     mix_factor = 1 / (n_steps - i)
...     x = x * (1 - mix_factor) + pred * mix_factor

>>> axs[1].imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap="Greys")
>>> axs[1].set_title("Generated Samples")
Finished epoch 0. Average loss for this epoch: 0.018925
Finished epoch 1. Average loss for this epoch: 0.012785
Finished epoch 2. Average loss for this epoch: 0.011694

This looks much better than our first set of results! You can explore tweaking the unet configuration or training for longer to get even better performance.

The corruption process

The DDPM paper describes a corruption process that adds a small amount of noise at every 'timestep'. Given $x_{t-1}$ for some timestep, we can get the next (slightly more noisy) version $x_t$ with:

$q(\mathbf{x}_t \vert \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}_{t-1}, \beta_t\mathbf{I}) \quad q(\mathbf{x}_{1:T} \vert \mathbf{x}_0) = \prod^T_{t=1} q(\mathbf{x}_t \vert \mathbf{x}_{t-1})$

That is, we take $x_{t-1}$, scale it by $\sqrt{1 - \beta_t}$ and add noise with variance $\beta_t$. This $\beta$ is defined for every t according to some schedule and determines how much noise is added per timestep. Now, we don't necessarily want to do this operation 500 times to get $x_{500}$, so we have another formula to get $x_t$ for any t given $x_0$:

$q(\mathbf{x}_t \vert \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1 - \bar{\alpha}_t) \mathbf{I})$ where $\bar{\alpha}_t = \prod_{i=1}^{t} \alpha_i$ and $\alpha_i = 1-\beta_i$

The maths notation always looks scary! Luckily the scheduler handles all of that for us (uncomment the next cell to check out the code). We can plot $\sqrt{\bar{\alpha}_t}$ (labelled sqrt_alpha_prod) and $\sqrt{(1 - \bar{\alpha}_t)}$ (labelled sqrt_one_minus_alpha_prod) to view how the input (x) and the noise are scaled and mixed across different timesteps:

# ??noise_scheduler.add_noise
>>> noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
>>> plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${\sqrt{\bar{\alpha}_t}}$")
>>> plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5, label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")
>>> plt.legend(fontsize="x-large")
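To relate this schedule back to our toy corruption, here is a quick sketch (my own addition, not from the original notebook) that overlays the x coefficient from our linear amount mix on the same kind of plot:

steps = torch.linspace(0, 1, 1000)
plt.plot(1 - steps, label="our corrupt(): x coefficient (1 - amount)")
plt.plot(noise_scheduler.alphas_cumprod ** 0.5, label=r"DDPM $\sqrt{\bar{\alpha}_t}$")
plt.legend(fontsize="x-large")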

Initially, the noisy x is mostly x (sqrt_alpha_prod ~= 1), but over time the contribution of x drops and the noise component increases. Unlike our linear mix of x and noise according to `amount`, this one gets noisy relatively quickly. We can visualize this on some data:

>>> # @markdown visualize the DDPM noising process for different timesteps:

>>> # Noise a batch of images to view the effect
>>> fig, axs = plt.subplots(3, 1, figsize=(16, 10))
>>> xb, yb = next(iter(train_dataloader))
>>> xb = xb.to(device)[:8]
>>> xb = xb * 2.0 - 1.0  # Map to (-1, 1)
>>> print("X shape", xb.shape)

>>> # Show clean inputs
>>> axs[0].imshow(torchvision.utils.make_grid(xb[:8])[0].detach().cpu(), cmap="Greys")
>>> axs[0].set_title("Clean X")

>>> # Add noise with scheduler
>>> timesteps = torch.linspace(0, 999, 8).long().to(device)
>>> noise = torch.randn_like(xb)  # << NB: randn not rand
>>> noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
>>> print("Noisy X shape", noisy_xb.shape)

>>> # Show noisy version (with and without clipping)
>>> axs[1].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu().clip(-1, 1), cmap="Greys")
>>> axs[1].set_title("Noisy X (clipped to (-1, 1))")
>>> axs[2].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu(), cmap="Greys")
>>> axs[2].set_title("Noisy X")
X shape torch.Size([8, 1, 28, 28])
Noisy X shape torch.Size([8, 1, 28, 28])
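To connect the formula above to what the scheduler is doing, here is a small sketch (mine, mirroring the maths rather than diffusers' actual add_noise code) that reproduces the result by hand using the variables from the cell above:

abar = noise_scheduler.alphas_cumprod.to(device)[timesteps].view(-1, 1, 1, 1)  # \bar{alpha}_t per image
manual_noisy_xb = abar.sqrt() * xb + (1 - abar).sqrt() * noise
print(torch.allclose(manual_noisy_xb, noisy_xb, atol=1e-4))  # expected: True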

Another dynamic at play: the DDPM version adds noise drawn from a Gaussian distribution (mean 0, standard deviation 1, from `torch.randn`) rather than the uniform noise between 0 and 1 (from `torch.rand`) that we used in our original `corrupt` function. In general, it also makes sense to normalize the training data. In the other notebook you'll see `Normalize(0.5, 0.5)` in the list of transforms, which maps the image data from (0, 1) to (-1, 1) and is 'good enough' for our purposes. We didn't do that in this notebook, but the visualization cell above includes it for more accurate scaling and visualization.
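For reference, a transform pipeline with that normalization might look something like this (a sketch; the other notebook's exact setup may differ slightly):

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),               # PIL image -> tensor in (0, 1)
        torchvision.transforms.Normalize([0.5], [0.5]),  # (0, 1) -> (-1, 1)
    ]
)
normalized_dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=transform)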

The training objective

In our toy example, we had the model try to predict the denoised image. In DDPM and many other diffusion model implementations, the model predicts the noise used in the corruption process (before scaling, i.e. the unit-variance noise). In code, that looks something like this:

noise = torch.randn_like(xb)  # << NB: randn not rand
noisy_x = noise_scheduler.add_noise(xb, noise, timesteps)
model_prediction = model(noisy_x, timesteps).sample
loss = F.mse_loss(model_prediction, noise)  # noise as the target

You may think that predicting the noise (from which we can derive what the denoised image looks like) is equivalent to predicting the denoised image directly. So why favour one over the other? Is it just for mathematical convenience?

It turns out there's another subtlety here. We compute the loss across different (randomly chosen) timesteps during training. The different objectives lead to different 'implicit weightings' of these losses, with the noise-prediction objective putting more weight on lower noise levels. You can pick more complex objectives to change this 'implicit loss weighting'. Or perhaps you choose a noise schedule that results in more examples at higher noise levels. Or perhaps you have the model predict a 'velocity' v, defined as a combination of the image and the noise that depends on the noise level (see ‘PROGRESSIVE DISTILLATION FOR FAST SAMPLING OF DIFFUSION MODELS’). Or perhaps you have the model predict the noise but then scale the loss by a factor that depends on the amount of noise, based on a bit of theory (see ‘Perception Prioritized Training of Diffusion Models’) or on experiments into which noise levels are most informative for the model (see ‘Elucidating the Design Space of Diffusion-Based Generative Models’). TL;DR: the choice of objective affects model performance, and research into what the 'best' option is remains ongoing.
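To make these targets concrete, here is a small sketch (my own, using the standard definitions rather than anything from this notebook) showing how the noise, clean-image and velocity targets are all derived from the same clean image, noise sample and $\bar{\alpha}_t$:

x0 = torch.randn(8, 1, 28, 28)      # stand-in for a batch of clean images
eps = torch.randn_like(x0)          # unit-variance Gaussian noise
abar_t = torch.tensor(0.5)          # \bar{alpha}_t for some timestep t

noisy = abar_t.sqrt() * x0 + (1 - abar_t).sqrt() * eps       # the corrupted model input

target_eps = eps                                             # 'epsilon': predict the noise (DDPM default)
target_x0 = x0                                               # 'sample': predict the clean image (our toy setup)
target_v = abar_t.sqrt() * eps - (1 - abar_t).sqrt() * x0    # 'v': velocity, as defined in Progressive Distillation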

At the moment, predicting the noise (you'll see epsilon or eps in some places) is the favoured approach, but over time we will likely see other objectives supported in the libraries and used in different situations.

Timestep conditioning

The UNet2DModel takes in both x and the timestep. The latter is turned into an embedding and fed into the model in a number of places.

The theory behind this is that by giving the model information about the noise level, it can perform its task better. While it is possible to train a model without this timestep conditioning, it does seem to help performance in some cases, and most implementations include it, at least in the current literature.
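For intuition, here is a rough sketch of the sinusoidal embedding idea (the same spirit as the Timesteps and TimestepEmbedding modules in the printout above, though not their exact implementation):

import math


def timestep_embedding(t, dim=32):
    """Map integer timesteps to sin/cos features at a range of frequencies."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half) / half)
    args = t[:, None].float() * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # shape (batch, dim)


emb = timestep_embedding(torch.tensor([0, 250, 999]))
# Inside the UNet, an embedding like this is passed through a small MLP and then mixed into the
# feature maps of each ResnetBlock2D (the time_emb_proj layers in the model printout).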

Sampling

Given a model that estimates the noise present in a noisy input (or predicts the denoised version), how do we produce new images?

We could feed in pure noise and just hope the model predicts a decent image as the denoised version in one step. However, as we saw in the experiments above, this doesn't usually work well. So instead, we take a number of small steps informed by the model's predictions, removing a little bit of the noise at a time.

Exactly how we take these steps depends on the sampling method used. We won't go into the theory too deeply, but some key design questions are:

  • How large a step should you take? In other words, what 'noise schedule' should you follow?
  • Do you use only the model's current prediction to inform the update step (like DDPM, DDIM and many others)? Do you evaluate the model several times to estimate higher-order gradients for larger, more accurate steps (higher-order methods and some discrete ODE solvers)? Or do you keep a history of past predictions to better inform the current update step (linear multi-step and ancestral samplers)?
  • Do you add extra noise (sometimes called churn) to make the sampling process more stochastic, or do you keep it completely deterministic? Many samplers control this with a parameter (such as 'eta' for DDIM samplers) so that the user can choose.

Research on sampling methods for diffusion models is evolving rapidly, and more and more methods for finding good solutions in fewer steps are being proposed. The brave and curious might find it interesting to browse through the code of the different implementations available in the diffusers library here, or check out the docs, which often link to the relevant papers.
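For reference, here is a hedged sketch of what a scheduler-driven sampling loop looks like with the DDPMScheduler we imported earlier. It assumes a UNet2DModel trained with the noise-prediction (epsilon) objective and moved to device; neither model trained in this notebook fits that description, so treat it as illustrative only:

noise_model = model.to(device)  # stand-in; in practice this should be a trained noise-predicting model

sampler = DDPMScheduler(num_train_timesteps=1000)
sampler.set_timesteps(50)  # use fewer inference steps than training timesteps

sample = torch.randn(4, 1, 28, 28).to(device)  # start from pure Gaussian noise
for t in sampler.timesteps:
    with torch.no_grad():
        noise_pred = noise_model(sample, t).sample            # predict the noise at this timestep
    sample = sampler.step(noise_pred, t, sample).prev_sample  # the scheduler removes a bit of noise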

In conclusion

Hopefully this has given you a helpful way of looking at diffusion models from a slightly different angle.

This notebook was written by Jonathan Whitaker for the Hugging Face course, and overlaps with a version included in his own course, ‘The Generative Landscape’. Check that course out if you'd like to see this basic example extended with noise and class conditioning. Questions or bugs can be raised through GitHub issues or Discord. You're also welcome to reach out via Twitter @johnowhitaker.
