Diffusion Models from Scratch
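Sometimes it helps to study the simplest possible version of something to understand how it works. We will try that in this notebook, beginning with a "toy" diffusion model to see how the different pieces work, and then examining how they differ from a more complex implementation.
We will look at:
- The corruption process (adding noise to data)
- What a UNet is, and how to implement an extremely minimal one from scratch
- Diffusion model training
- Sampling theory
Then we will compare our versions with the diffusers DDPM implementation, exploring:
- Improvements over our mini UNet
- The DDPM noise schedule
- Differences in the training objective
- Timestep conditioning
- Sampling approaches
This notebook is fairly in-depth, and can safely be skipped if you are not excited about a from-scratch deep dive!
It is also worth noting that most of the code here is for illustrative purposes, and I would not recommend directly adopting any of it for your own work (unless you are just trying to improve on the examples shown here for learning purposes).
Setup and imports: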
>>> %pip install -q diffusers
>>> import torch
>>> import torchvision
>>> from torch import nn
>>> from torch.nn import functional as F
>>> from torch.utils.data import DataLoader
>>> from diffusers import DDPMScheduler, UNet2DModel
>>> from matplotlib import pyplot as plt
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> print(f"Using device: {device}")
Using device: cuda
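The Data
Here we will test things with a very small dataset: MNIST. If you would like to give the model a slightly harder challenge without changing anything else, `torchvision.datasets.FashionMNIST` should work as a drop-in replacement.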
>>> dataset = torchvision.datasets.MNIST(
...     root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor()
... )
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist/MNIST/raw/train-images-idx3-ubyte.gz
>>> train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
>>> x, y = next(iter(train_dataloader))
>>> print("Input shape:", x.shape)
>>> print("Labels:", y)
>>> plt.imshow(torchvision.utils.make_grid(x)[0], cmap="Greys")
Input shape: torch.Size([8, 1, 28, 28])
Labels: tensor([1, 9, 7, 3, 5, 2, 1, 4])
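Each image is a greyscale drawing of a handwritten digit, 28px by 28px, with pixel values ranging from 0 to 1.
The Corruption Process
Pretend you have not read any diffusion model papers, but you know the process involves adding noise. How would you do it?
We probably want an easy way to control the amount of corruption. So what if we take in a parameter `amount` for the amount of noise to add, and then do this: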
noise = torch.rand_like(x)
noisy_x = (1-amount)*x + amount*noise
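If `amount` = 0, we get back the input without any changes. If `amount` reaches 1, we get back noise with no trace of the input x. Because the result is a weighted average of the input and the noise, and both lie between 0 and 1, the output stays in the same range (0 to 1).
We can implement this fairly easily (just watch the shapes so you do not get burnt by broadcasting rules):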
def corrupt(x, amount):
    """Corrupt the input `x` by mixing it with noise according to `amount`"""
    noise = torch.rand_like(x)
    amount = amount.view(-1, 1, 1, 1)  # Sort shape so broadcasting works
    return x * (1 - amount) + noise * amount
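To make the broadcasting warning concrete, here is a small sketch (not part of the original notebook) of what happens when the reshape is skipped. The tensors `x28` and `amount28` are made-up examples: with a batch of 8 the multiplication simply fails, but with a batch size that happens to equal the image width (28) it runs without error and quietly scales image columns instead of whole images.

```python
import torch

x = torch.ones(8, 1, 28, 28)
amount = torch.linspace(0, 1, 8)

# Without the reshape, broadcasting lines `amount` up with the LAST dimension (width 28)
try:
    x * amount
    failed = False
except RuntimeError:
    failed = True  # 8 values cannot be matched against a width of 28

# Hypothetical unlucky case: batch size equal to the image width
x28 = torch.ones(28, 1, 28, 28)
amount28 = torch.linspace(0, 1, 28)
wrong = x28 * amount28                     # no error: every image gets the same column-wise ramp
right = x28 * amount28.view(-1, 1, 1, 1)   # one scale factor per image, as intended

print("unreshaped multiply failed for batch of 8:", failed)
print("all images identical in the wrong version:", torch.equal(wrong[0], wrong[5]))
print("per-image scales in the right version:", right[:, 0, 0, 0][:4])
```

The silent case is the dangerous one, which is why the reshape to `(-1, 1, 1, 1)` is worth keeping even when the code appears to work without it.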
And we can look at the results visually to check that it works as expected:
>>> # Plotting the input data
>>> fig, axs = plt.subplots(2, 1, figsize=(12, 5))
>>> axs[0].set_title("Input data")
>>> axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap="Greys")
>>> # Adding noise
>>> amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
>>> noised_x = corrupt(x, amount)
>>> # Plotting the noised version
>>> axs[1].set_title("Corrupted data (-- amount increases -->)")
>>> axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap="Greys")
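As the noise amount approaches 1, our data begins to look like pure random noise. But for most noise amounts, you can still guess the digit fairly well. Do you think this is optimal?
The Model
We would like a model that takes in a 28px noisy image and outputs a prediction of the same shape. A popular choice here is an architecture called the UNet. Originally invented for segmentation tasks in medical imagery, a UNet consists of a "contracting path" through which data is compressed and an "expanding path" through which it expands back up to the original dimension (similar to an autoencoder), but it also has skip connections that allow information and gradients to flow across at different levels.
Some UNets feature complex blocks at each stage, but for this toy demo we will build a minimal example that takes in a one-channel image and passes it through three convolutional layers on the down path (the `down_layers` in the code) and three on the up path, with skip connections between the down and up layers. We will use max pooling for downsampling and `nn.Upsample` for upsampling, rather than relying on learnable layers as more complex UNets do. In terms of the number of channels in the output of each layer, the rough architecture is 32, 64, 64 on the way down and 64, 32, 1 on the way up.
This is what that looks like in code: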
class BasicUNet(nn.Module):
    """A minimal UNet implementation."""
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down_layers = torch.nn.ModuleList(
            [
                nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
                nn.Conv2d(32, 64, kernel_size=5, padding=2),
                nn.Conv2d(64, 64, kernel_size=5, padding=2),
            ]
        )
        self.up_layers = torch.nn.ModuleList(
            [
                nn.Conv2d(64, 64, kernel_size=5, padding=2),
                nn.Conv2d(64, 32, kernel_size=5, padding=2),
                nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
            ]
        )
        self.act = nn.SiLU()  # The activation function
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)
    def forward(self, x):
        h = []
        for i, l in enumerate(self.down_layers):
            x = self.act(l(x))  # Through the layer and the activation function
            if i < 2:  # For all but the third (final) down layer:
                h.append(x)  # Storing output for skip connection
                x = self.downscale(x)  # Downscale ready for the next layer
        for i, l in enumerate(self.up_layers):
            if i > 0:  # For all except the first up layer
                x = self.upscale(x)  # Upscale
                x += h.pop()  # Fetching stored output (skip connection)
            x = self.act(l(x))  # Through the layer and the activation function
        return x
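To follow what happens to the spatial size of an image on its way through this network, the sketch below traces it with plain arithmetic. The helper names `conv_out` and `trace_sizes` are made up for illustration. It uses the standard convolution output-size formula, with halving for `MaxPool2d(2)` and doubling for `Upsample(scale_factor=2)`.

```python
def conv_out(n, kernel=5, padding=2, stride=1):
    """Spatial size after a convolution; kernel 5 with padding 2 keeps the size unchanged."""
    return (n + 2 * padding - kernel) // stride + 1

def trace_sizes(n=28):
    """Follow the height/width of an n x n input through the down and up paths."""
    sizes = [("input", n)]
    skips = []
    for i in range(3):                  # down path
        n = conv_out(n)
        if i < 2:
            skips.append(n)             # size of the tensor stored for the skip connection
            n = n // 2                  # MaxPool2d(2)
        sizes.append((f"down {i}", n))
    for i in range(3):                  # up path
        if i > 0:
            n = n * 2                   # Upsample(scale_factor=2)
            skip = skips.pop()
            if n != skip:
                raise ValueError(f"skip connection mismatch: {n} vs {skip}")
        n = conv_out(n)
        sizes.append((f"up {i}", n))
    return sizes

for name, size in trace_sizes(28):
    print(f"{name}: {size}x{size}")
```

Under these assumptions a 28px input goes 28 → 14 → 7 and back up to 28, so the skip connections line up. An input size that is not divisible by 4 (30, say) would leave the upsampled tensor and the stored skip tensor with different sizes, which is one reason this toy network suits 28px images so well.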
We can verify that the output shape is the same as the input, as we expect:
net = BasicUNet()
x = torch.rand(8, 1, 28, 28)
net(x).shape
This network has just over 300,000 parameters:
sum([p.numel() for p in net.parameters()])
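The parameter count can also be checked by hand. The sketch below (the helper `conv2d_params` is a made-up name) uses the fact that a 2D convolution with bias has `in_channels * out_channels * kernel * kernel` weights plus `out_channels` biases, and applies it to the six layers defined in `BasicUNet`.

```python
def conv2d_params(c_in, c_out, kernel=5):
    """Weights plus biases of a single Conv2d layer with a square kernel."""
    return c_in * c_out * kernel * kernel + c_out

down = [(1, 32), (32, 64), (64, 64)]   # (in_channels, out_channels) on the down path
up = [(64, 64), (64, 32), (32, 1)]     # (in_channels, out_channels) on the up path

per_layer = [conv2d_params(c_in, c_out) for c_in, c_out in down + up]
total = sum(per_layer)
print(per_layer)  # [832, 51264, 102464, 102464, 51232, 801]
print(total)      # 309057
```

The activation, pooling and upsampling layers have no learnable parameters, so the six convolutions account for the whole total.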
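You can explore changing the number of channels in each layer or swapping in different architectures if you like.
Training the network
So what exactly should the model do? Again, there are various takes on this, but for this demo let's pick a simple framing: given a corrupted input noisy_x, the model should output its best guess for what the original x looks like. We will compare this guess to the true value using the mean squared error.
We can now have a go at training the network:
- Get a batch of data
- Corrupt it by random amounts
- Feed it through the model
- Compare the model predictions with the clean images to calculate our loss
- Update the model's parameters accordingly.
Feel free to modify this and see if you can get it working better!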
>>> # Dataloader (you can mess with batch size)
>>> batch_size = 128
>>> train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
>>> # How many runs through the data should we do?
>>> n_epochs = 3
>>> # Create the network
>>> net = BasicUNet()
>>> net.to(device)
>>> # Our loss function
>>> loss_fn = nn.MSELoss()
>>> # The optimizer
>>> opt = torch.optim.Adam(net.parameters(), lr=1e-3)
>>> # Keeping a record of the losses for later viewing
>>> losses = []
>>> # The training loop
>>> for epoch in range(n_epochs):
...     for x, y in train_dataloader:
...         # Get some data and prepare the corrupted version
...         x = x.to(device)  # Data on the GPU
...         noise_amount = torch.rand(x.shape[0]).to(device)  # Pick random noise amounts
...         noisy_x = corrupt(x, noise_amount)  # Create our noisy x
...         # Get the model prediction
...         pred = net(noisy_x)
...         # Calculate the loss
...         loss = loss_fn(pred, x)  # How close is the output to the true 'clean' x?
...         # Backprop and update the params:
...         opt.zero_grad()
...         loss.backward()
...         opt.step()
...         # Store the loss for later
...         losses.append(loss.item())
...     # Print out the average of the loss values for this epoch:
...     avg_loss = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
...     print(f"Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}")
>>> # View the loss curve
>>> plt.plot(losses)
>>> plt.ylim(0, 0.1)
Finished epoch 0. Average loss for this epoch: 0.026736
Finished epoch 1. Average loss for this epoch: 0.020692
Finished epoch 2. Average loss for this epoch: 0.018887
We can see how the model is doing by grabbing a batch of data, corrupting it by different amounts and then looking at the model's predictions:
>>> # @markdown Visualizing model predictions on noisy inputs:
>>> # Fetch some data
>>> x, y = next(iter(train_dataloader))
>>> x = x[:8] # Only using the first 8 for easy plotting
>>> # Corrupt with a range of amounts
>>> amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
>>> noised_x = corrupt(x, amount)
>>> # Get the model predictions
>>> with torch.no_grad():
...     preds = net(noised_x.to(device)).detach().cpu()
>>> # Plot
>>> fig, axs = plt.subplots(3, 1, figsize=(12, 7))
>>> axs[0].set_title("Input data")
>>> axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap="Greys")
>>> axs[1].set_title("Corrupted data")
>>> axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap="Greys")
>>> axs[2].set_title("Network Predictions")
>>> axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap="Greys")
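You can see that for the lower amounts the predictions are pretty good! But as the level gets very high there is less and less for the model to work with, and by the time we get to amount=1 it outputs a blurry mess close to the mean of the dataset, hedging its bets on what the output might look like.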
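The blurry output at high noise is what a mean squared error loss encourages: when the input carries no information, the single guess with the lowest expected squared error is the mean of the possible targets. A tiny sketch with made-up values for one pixel across five hypothetical images:

```python
def mse(guess, targets):
    """Mean squared error of one constant guess against several target values."""
    return sum((guess - t) ** 2 for t in targets) / len(targets)

pixel_values = [0.0, 0.0, 1.0, 1.0, 1.0]        # made-up values of one pixel in five images
mean = sum(pixel_values) / len(pixel_values)

candidates = [i / 100 for i in range(101)]       # try every guess from 0.00 to 1.00
best = min(candidates, key=lambda g: mse(g, pixel_values))

print("mean of targets:", mean)
print("best constant guess:", best)
print("MSE at 0, at the mean, at 1:", mse(0.0, pixel_values), mse(mean, pixel_values), mse(1.0, pixel_values))
```

Neither "ink" (1.0) nor "background" (0.0) minimises the loss here. The grey value in between does, and applying the same logic to every pixel gives the averaged, blurry digit.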
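Sampling
If our predictions at high noise levels are not very good, how do we generate images?
Well, what if we start from random noise, look at the model's prediction, but then only move a small amount towards that prediction, say 20% of the way there? Now we have a very noisy image in which there is perhaps a hint of structure, which we can feed into the model to get a new prediction. The hope is that this new prediction is slightly better than the first one (since our starting point is slightly less noisy), so we can take another small step with this new, better prediction.
Repeat a few times and (if all goes well) we get an image out! Here is this process illustrated over just 5 steps, visualizing the model input (left) and the predicted denoised image (right) at each stage. Note that even though the model predicts a denoised image as early as step 1, we only move x part of the way there. Over a few steps the structures appear and are refined, until we get our final outputs.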
>>> # @markdown Sampling strategy: Break the process into 5 steps and move 1/5'th of the way there each time:
>>> n_steps = 5
>>> x = torch.rand(8, 1, 28, 28).to(device) # Start from random
>>> step_history = [x.detach().cpu()]
>>> pred_output_history = []
>>> for i in range(n_steps):
...     with torch.no_grad():  # No need to track gradients during inference
...         pred = net(x)  # Predict the denoised x0
...     pred_output_history.append(pred.detach().cpu())  # Store model output for plotting
...     mix_factor = 1 / (n_steps - i)  # How much we move towards the prediction
...     x = x * (1 - mix_factor) + pred * mix_factor  # Move part of the way there
...     step_history.append(x.detach().cpu())  # Store step for plotting
>>> fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)
>>> axs[0, 0].set_title("x (model input)")
>>> axs[0, 1].set_title("model prediction")
>>> for i in range(n_steps):
...     axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap="Greys")
...     axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0, 1), cmap="Greys")
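One detail of the loop above is easy to miss: `mix_factor = 1 / (n_steps - i)` is not a constant 20%. It starts at 1/5 and grows to 1 on the final step, so the last update lands exactly on the model's prediction. The sketch below (plain arithmetic, not part of the notebook) prints the mix factor at each step together with how much of the original random starting point is still present afterwards.

```python
n_steps = 5

mix_factors = [1 / (n_steps - i) for i in range(n_steps)]

# Weight of the initial random x that survives after each update x = x*(1-m) + pred*m
remaining = 1.0
noise_weights = []
for m in mix_factors:
    remaining *= 1 - m
    noise_weights.append(remaining)

for i, (m, w) in enumerate(zip(mix_factors, noise_weights)):
    print(f"step {i}: mix_factor={m:.3f}, weight of initial noise left={w:.3f}")
```

The surviving weight falls linearly (0.8, 0.6, 0.4, 0.2, 0), which mirrors the linear `amount` used by `corrupt` during training, run in reverse.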
We can split the process up into more steps, and hope for better images that way:
>>> # @markdown Showing more results, using 40 sampling steps
>>> n_steps = 40
>>> x = torch.rand(64, 1, 28, 28).to(device)
>>> for i in range(n_steps):
...     noise_amount = torch.ones((x.shape[0],)).to(device) * (1 - (i / n_steps))  # Starting high going low
...     with torch.no_grad():
...         pred = net(x)
...     mix_factor = 1 / (n_steps - i)
...     x = x * (1 - mix_factor) + pred * mix_factor
>>> fig, ax = plt.subplots(1, 1, figsize=(12, 12))
>>> ax.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap="Greys")
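Not great, but there are some recognizable digits there! You can experiment with training for longer (say, 10 or 20 epochs) and tweaking the model config, learning rate, optimizer and so on. Also, do not forget that FashionMNIST is a one-line replacement if you want a slightly harder dataset to try.
Comparison To DDPM
In this section we will look at how our toy implementation differs from the approach used in the other notebook (Introduction to Diffusers), which is based on the DDPM paper.
We will see that:
- The diffusers `UNet2DModel` is a bit more advanced than our BasicUNet
- The corruption process is handled differently
- The training objective is different, involving predicting the noise rather than the denoised image
- The model is conditioned on the amount of noise present via timestep conditioning, where t is passed as an additional argument to the forward method.
- There are a number of different sampling strategies available, which should work better than our simplistic version above.
A number of improvements have been suggested since the DDPM paper came out, but this example is hopefully instructive as to the different design decisions available. Once you have read through this, you may enjoy diving into the paper ‘Elucidating the Design Space of Diffusion-Based Generative Models’, which explores all of these components in some detail and makes new recommendations for how to get the best performance.
If all of this is too technical or intimidating, do not worry! Feel free to skip the rest of this notebook or save it for a rainy day.
The UNet
The diffusers UNet2DModel has a number of improvements over our basic UNet above:
- GroupNorm applies group-wise normalization to the inputs of each block
- Dropout layers for smoother training
- Multiple resnet layers per block (if `layers_per_block` is not set to 1)
- Attention (usually used only at lower-resolution blocks)
- Conditioning on the timestep.
- Downsampling and upsampling blocks with learnable parameters
Let's create and inspect a UNet2DModel: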
>>> model = UNet2DModel(
...     sample_size=28,  # the target image resolution
...     in_channels=1,  # the number of input channels, 3 for RGB images
...     out_channels=1,  # the number of output channels
...     layers_per_block=2,  # how many ResNet layers to use per UNet block
...     block_out_channels=(32, 64, 64),  # Roughly matching our basic unet example
...     down_block_types=(
...         "DownBlock2D",  # a regular ResNet downsampling block
...         "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
...         "AttnDownBlock2D",
...     ),
...     up_block_types=(
...         "AttnUpBlock2D",
...         "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
...         "UpBlock2D",  # a regular ResNet upsampling block
...     ),
... )
>>> print(model)
UNet2DModel( (conv_in): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_proj): Timesteps() (time_embedding): TimestepEmbedding( (linear_1): Linear(in_features=32, out_features=128, bias=True) (act): SiLU() (linear_2): Linear(in_features=128, out_features=128, bias=True) ) (down_blocks): ModuleList( (0): DownBlock2D( (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 32, eps=1e-05, affine=True) (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=32, bias=True) (norm2): GroupNorm(32, 32, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 32, eps=1e-05, affine=True) (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=32, bias=True) (norm2): GroupNorm(32, 32, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) ) (downsamplers): ModuleList( (0): Downsample2D( (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) ) ) ) (1): AttnDownBlock2D( (attentions): ModuleList( (0): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) (1): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) ) 
(resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 32, eps=1e-05, affine=True) (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1)) ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) ) (downsamplers): ModuleList( (0): Downsample2D( (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) ) ) ) (2): AttnDownBlock2D( (attentions): ModuleList( (0): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) (1): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) ) (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, 
eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) ) ) ) (up_blocks): ModuleList( (0): AttnUpBlock2D( (attentions): ModuleList( (0): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) (1): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) (2): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) ) (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-05, affine=True) (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), 
stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1)) ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-05, affine=True) (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1)) ) (2): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-05, affine=True) (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1)) ) ) (upsamplers): ModuleList( (0): Upsample2D( (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) ) (1): AttnUpBlock2D( (attentions): ModuleList( (0): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) (1): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) (2): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, 
out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) ) (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-05, affine=True) (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1)) ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-05, affine=True) (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1)) ) (2): ResnetBlock2D( (norm1): GroupNorm(32, 96, eps=1e-05, affine=True) (conv1): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1)) ) ) (upsamplers): ModuleList( (0): Upsample2D( (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) ) (2): UpBlock2D( (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 96, eps=1e-05, affine=True) (conv1): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): 
Linear(in_features=128, out_features=32, bias=True) (norm2): GroupNorm(32, 32, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1)) ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=32, bias=True) (norm2): GroupNorm(32, 32, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1)) ) (2): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=32, bias=True) (norm2): GroupNorm(32, 32, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1)) ) ) ) ) (mid_block): UNetMidBlock2D( (attentions): ModuleList( (0): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) ) (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) ) ) (conv_norm_out): GroupNorm(32, 32, eps=1e-05, affine=True) (conv_act): SiLU() (conv_out): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) )
As you can see, there is a little more going on! It also has significantly more parameters than our BasicUNet:
sum([p.numel() for p in model.parameters()]) # 1.7M vs the ~309k parameters of the BasicUNet
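We can replicate the training shown above using this model in place of our original one. We need to pass both x and timestep to the model (here I always pass t=0 to show that it works without this timestep conditioning and to keep the sampling code simple, but you can also try feeding in `(amount*1000)` to get a timestep equivalent from the corruption amount). Changed lines are marked with `#<<<` if you want to inspect the code.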
>>> # @markdown Trying UNet2DModel instead of BasicUNet:
>>> # Dataloader (you can mess with batch size)
>>> batch_size = 128
>>> train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
>>> # How many runs through the data should we do?
>>> n_epochs = 3
>>> # Create the network
>>> net = UNet2DModel(
...     sample_size=28,  # the target image resolution
...     in_channels=1,  # the number of input channels, 3 for RGB images
...     out_channels=1,  # the number of output channels
...     layers_per_block=2,  # how many ResNet layers to use per UNet block
...     block_out_channels=(32, 64, 64),  # Roughly matching our basic unet example
...     down_block_types=(
...         "DownBlock2D",  # a regular ResNet downsampling block
...         "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
...         "AttnDownBlock2D",
...     ),
...     up_block_types=(
...         "AttnUpBlock2D",
...         "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
...         "UpBlock2D",  # a regular ResNet upsampling block
...     ),
... ) # <<<
>>> net.to(device)
>>> # Our loss function
>>> loss_fn = nn.MSELoss()
>>> # The optimizer
>>> opt = torch.optim.Adam(net.parameters(), lr=1e-3)
>>> # Keeping a record of the losses for later viewing
>>> losses = []
>>> # The training loop
>>> for epoch in range(n_epochs):
...     for x, y in train_dataloader:
...         # Get some data and prepare the corrupted version
...         x = x.to(device)  # Data on the GPU
...         noise_amount = torch.rand(x.shape[0]).to(device)  # Pick random noise amounts
...         noisy_x = corrupt(x, noise_amount)  # Create our noisy x
...         # Get the model prediction
...         pred = net(noisy_x, 0).sample  # <<< Using timestep 0 always, adding .sample
...         # Calculate the loss
...         loss = loss_fn(pred, x)  # How close is the output to the true 'clean' x?
...         # Backprop and update the params:
...         opt.zero_grad()
...         loss.backward()
...         opt.step()
...         # Store the loss for later
...         losses.append(loss.item())
...     # Print out the average of the loss values for this epoch:
...     avg_loss = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
...     print(f"Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}")
>>> # Plot losses and some samples
>>> fig, axs = plt.subplots(1, 2, figsize=(12, 5))
>>> # Losses
>>> axs[0].plot(losses)
>>> axs[0].set_ylim(0, 0.1)
>>> axs[0].set_title("Loss over time")
>>> # Samples
>>> n_steps = 40
>>> x = torch.rand(64, 1, 28, 28).to(device)
>>> for i in range(n_steps):
...     noise_amount = torch.ones((x.shape[0],)).to(device) * (1 - (i / n_steps))  # Starting high going low
...     with torch.no_grad():
...         pred = net(x, 0).sample
...     mix_factor = 1 / (n_steps - i)
...     x = x * (1 - mix_factor) + pred * mix_factor
>>> axs[1].imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap="Greys")
>>> axs[1].set_title("Generated Samples")
Finished epoch 0. Average loss for this epoch: 0.018925
Finished epoch 1. Average loss for this epoch: 0.012785
Finished epoch 2. Average loss for this epoch: 0.011694
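This looks quite a bit better than our first set of results! You can explore tweaking the unet configuration or training for longer to get even better performance.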
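If you want to try the `(amount*1000)` idea mentioned earlier instead of always passing 0, the mapping from a corruption amount to a timestep could look something like the sketch below. The helper `amount_to_timestep` is a made-up name, and clamping to `num_train_timesteps - 1` is my own assumption to keep the result inside the 0 to 999 range that a 1000-step scheduler uses.

```python
def amount_to_timestep(amount, num_train_timesteps=1000):
    """Map a corruption amount in [0, 1] to an integer timestep in [0, num_train_timesteps - 1]."""
    t = int(round(amount * num_train_timesteps))
    return max(0, min(num_train_timesteps - 1, t))  # clamp so amount=1.0 stays in range

for amount in (0.0, 0.25, 0.5, 1.0):
    print(amount, "->", amount_to_timestep(amount))
```

In the training loop the same idea would be applied to the whole `noise_amount` tensor, and the resulting timesteps passed as the second argument to the model in place of the constant 0.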
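The Corruption Process
The DDPM paper describes a corruption process that adds a small amount of noise at every "timestep". Given $x_{t-1}$ for some timestep, we can get the next (slightly noisier) version $x_t$ with:
$q(\mathbf{x}_t \vert \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}_{t-1}, \beta_t\mathbf{I}) \quad q(\mathbf{x}_{1:T} \vert \mathbf{x}_0) = \prod^T_{t=1} q(\mathbf{x}_t \vert \mathbf{x}_{t-1})$
That is, we take $x_{t-1}$, scale it by $\sqrt{1 - \beta_t}$ and add Gaussian noise with variance $\beta_t$ (unit-variance noise scaled by $\sqrt{\beta_t}$). This $\beta$ is defined for every t according to some schedule, and determines how much noise is added per timestep. Now, we do not necessarily want to do this operation 500 times to get $x_{500}$, so we have another formula to get $x_t$ for any t given $x_0$:
$\begin{aligned} q(\mathbf{x}_t \vert \mathbf{x}_0) &= \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1 - \bar{\alpha}_t) \mathbf{I}) \end{aligned}$ where $\bar{\alpha}_t = \prod_{i=1}^t \alpha_i$ and $\alpha_i = 1-\beta_i$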
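To make the shortcut formula concrete, here is a plain-Python sketch that builds a $\beta$ schedule, accumulates $\bar{\alpha}_t$, and checks that applying the one-step rule 500 times gives the same signal scale and noise variance as jumping there directly. The linear schedule from 0.0001 to 0.02 over 1000 steps is an assumption on my part (I believe it matches the `DDPMScheduler` defaults, but check the library if it matters), and `linear_betas` is a made-up helper name.

```python
import math

def linear_betas(n=1000, beta_start=1e-4, beta_end=0.02):
    """Assumed linear beta schedule: evenly spaced values from beta_start to beta_end."""
    step = (beta_end - beta_start) / (n - 1)
    return [beta_start + i * step for i in range(n)]

betas = linear_betas()

# alpha_bar at (0-indexed) timestep t is the product of (1 - beta_i) over the first t + 1 betas
alphas_cumprod = []
running = 1.0
for beta in betas:
    running *= 1.0 - beta
    alphas_cumprod.append(running)

# Apply the one-step rule 500 times, tracking the scale on x_0 and the accumulated noise variance
signal_scale, noise_variance = 1.0, 0.0
for beta in betas[:500]:
    signal_scale *= math.sqrt(1.0 - beta)
    noise_variance = (1.0 - beta) * noise_variance + beta

a = alphas_cumprod[499]
print("step by step:", signal_scale, noise_variance)
print("closed form: ", math.sqrt(a), 1.0 - a)

for t in (0, 250, 500, 999):
    a = alphas_cumprod[t]
    print(t, "sqrt(alpha_bar) =", round(math.sqrt(a), 4), " sqrt(1 - alpha_bar) =", round(math.sqrt(1 - a), 4))
```

The two printed pairs agree up to floating point error, which is the content of the second formula: after t steps the image is scaled by $\sqrt{\bar{\alpha}_t}$ and the accumulated noise has variance $1 - \bar{\alpha}_t$.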
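Maths notation always looks scary! Luckily the scheduler handles all of that for us (uncomment the next cell to check out the code). We can plot $\sqrt{\bar{\alpha}_t}$ (labelled as `sqrt_alpha_prod`) and $\sqrt{(1 - \bar{\alpha}_t)}$ (labelled as `sqrt_one_minus_alpha_prod`) to see how the input (x) and the noise are scaled and mixed across different timesteps: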
# ??noise_scheduler.add_noise
>>> noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
>>> plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${\sqrt{\bar{\alpha}_t}}$")
>>> plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5, label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")
>>> plt.legend(fontsize="x-large")
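Initially, the noisy x is mostly x (sqrt_alpha_prod ~= 1), but over time the contribution of x drops and the noise component increases. Unlike our linear mix of x and noise according to `amount`, this one gets noisy relatively quickly. We can visualize this on some data: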
>>> # @markdown visualize the DDPM noising process for different timesteps:
>>> # Noise a batch of images to view the effect
>>> fig, axs = plt.subplots(3, 1, figsize=(16, 10))
>>> xb, yb = next(iter(train_dataloader))
>>> xb = xb.to(device)[:8]
>>> xb = xb * 2.0 - 1.0 # Map to (-1, 1)
>>> print("X shape", xb.shape)
>>> # Show clean inputs
>>> axs[0].imshow(torchvision.utils.make_grid(xb[:8])[0].detach().cpu(), cmap="Greys")
>>> axs[0].set_title("Clean X")
>>> # Add noise with scheduler
>>> timesteps = torch.linspace(0, 999, 8).long().to(device)
>>> noise = torch.randn_like(xb) # << NB: randn not rand
>>> noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
>>> print("Noisy X shape", noisy_xb.shape)
>>> # Show noisy version (with and without clipping)
>>> axs[1].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu().clip(-1, 1), cmap="Greys")
>>> axs[1].set_title("Noisy X (clipped to (-1, 1))")
>>> axs[2].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu(), cmap="Greys")
>>> axs[2].set_title("Noisy X")
X shape torch.Size([8, 1, 28, 28])
Noisy X shape torch.Size([8, 1, 28, 28])
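Another dynamic at play: the DDPM version adds noise drawn from a Gaussian distribution (mean 0, standard deviation 1, from `torch.randn`) rather than the uniform noise between 0 and 1 (from `torch.rand`) we used in our original `corrupt` function. In general, it makes sense to normalize the training data as well. In the other notebook, you will see `Normalize(0.5, 0.5)` in the list of transforms, which maps the image data from (0, 1) to (-1, 1) and is "good enough" for our purposes. We did not do that for this notebook, but the visualization cell above adds it in for more accurate scaling and visualization.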
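A quick numerical sketch of both points, using only the standard library: `normalize` below mimics what `Normalize(0.5, 0.5)` does to a single pixel value (subtract the mean, divide by the standard deviation), and the two sample lists stand in for `torch.rand` and `torch.randn`. The names and the sample size are my own choices for illustration.

```python
import random
import statistics

def normalize(pixel, mean=0.5, std=0.5):
    """Same arithmetic as Normalize(0.5, 0.5) applied to one pixel value."""
    return (pixel - mean) / std

print([normalize(p) for p in (0.0, 0.5, 1.0)])  # [-1.0, 0.0, 1.0]

random.seed(0)
uniform = [random.random() for _ in range(100_000)]         # like torch.rand: uniform on [0, 1)
gaussian = [random.gauss(0.0, 1.0) for _ in range(100_000)]  # like torch.randn: mean 0, s.d. 1

print("uniform  mean/std:", round(statistics.mean(uniform), 3), round(statistics.pstdev(uniform), 3))
print("gaussian mean/std:", round(statistics.mean(gaussian), 3), round(statistics.pstdev(gaussian), 3))
```

Uniform noise on (0, 1) is centred on 0.5 with a standard deviation of roughly 0.29, whereas Gaussian noise is centred on 0 with unit standard deviation. Mapping the images to (-1, 1) centres the data around 0 as well, so data and noise live on comparable scales.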
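Training Objective
In our toy example, we had the model try to predict the denoised image. In DDPM and many other diffusion model implementations, the model predicts the noise used in the corruption process (before scaling, so unit-variance noise). In code, it looks something like: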
noise = torch.randn_like(xb)  # << NB: randn not rand
noisy_x = noise_scheduler.add_noise(xb, noise, timesteps)  # Corrupt the clean batch xb
model_prediction = model(noisy_x, timesteps).sample  # The model predicts the noise
loss = F.mse_loss(model_prediction, noise)  # noise as the target
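You may think that predicting the noise (from which we can derive what the denoised image looks like) is equivalent to just predicting the denoised image directly. So why favour one over the other? Is it just for mathematical convenience?
It turns out there is another subtlety here. We compute the loss across different (randomly chosen) timesteps during training. These different objectives lead to a different "implicit weighting" of these losses, where predicting the noise puts more weight on lower noise levels. You can pick more complex objectives to change this "implicit loss weighting". Or perhaps you choose a noise schedule that results in more examples at higher noise levels. Perhaps you have the model predict a "velocity" v, which we define as a combination of both the image and the noise that depends on the noise level (see ‘PROGRESSIVE DISTILLATION FOR FAST SAMPLING OF DIFFUSION MODELS’). Perhaps you have the model predict the noise but then scale the loss by some factor that depends on the amount of noise, based on a bit of theory (see ‘Perception Prioritized Training of Diffusion Models’) or on experiments that try to find which noise levels are most informative to the model (see ‘Elucidating the Design Space of Diffusion-Based Generative Models’). TL;DR: the choice of objective has an effect on model performance, and research is ongoing into what the "best" option is.
At the moment, predicting the noise (you will see epsilon or eps in some places) is the favoured approach, but over time we will likely see other objectives supported in the library and used in different situations.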
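The "implicit weighting" can be made concrete with a little algebra. Since $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon$, a noise prediction and an image prediction can be converted into each other, and an error in one maps to a scaled error in the other: the squared noise error equals $\bar{\alpha}_t / (1 - \bar{\alpha}_t)$ times the squared image error. The sketch below checks this numerically for a single made-up scalar "pixel". The helper names and the numbers are illustrative assumptions, not library code.

```python
import math

def x0_from_eps(x_t, eps_pred, alpha_bar):
    """Image estimate implied by a noise prediction."""
    return (x_t - math.sqrt(1 - alpha_bar) * eps_pred) / math.sqrt(alpha_bar)

def eps_from_x0(x_t, x0_pred, alpha_bar):
    """Noise estimate implied by an image prediction."""
    return (x_t - math.sqrt(alpha_bar) * x0_pred) / math.sqrt(1 - alpha_bar)

def loss_ratio(alpha_bar, x0=0.3, eps=-1.2, x0_error=0.1):
    """Squared error measured on the noise divided by the same error measured on the image."""
    x_t = math.sqrt(alpha_bar) * x0 + math.sqrt(1 - alpha_bar) * eps
    x0_pred = x0 + x0_error                      # an image prediction that is off by a fixed amount
    eps_pred = eps_from_x0(x_t, x0_pred, alpha_bar)
    return (eps_pred - eps) ** 2 / (x0_pred - x0) ** 2

for alpha_bar in (0.99, 0.5, 0.01):              # low, medium and high noise
    print(alpha_bar, loss_ratio(alpha_bar), alpha_bar / (1 - alpha_bar))
```

At low noise ($\bar{\alpha}_t$ close to 1) the same image error costs far more under the noise objective, and at high noise it costs far less, which is the sense in which predicting the noise puts more weight on lower noise levels.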
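Timestep Conditioning
The UNet2DModel takes in both x and timestep. The latter is turned into an embedding and fed into the model in a number of places.
The theory behind this is that by giving the model information about the noise level, it can better perform its task. While it is possible to train a model without this timestep conditioning, it does seem to help performance in some cases, and most implementations include it, at least in the current literature.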
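As a rough picture of what "turned into an embedding" can mean, here is a minimal sinusoidal embedding in plain Python. This is a sketch of the general technique only: the function `sinusoidal_embedding` is a made-up name, and the exact frequency spacing and ordering used inside diffusers may differ. The dimension of 32 is chosen to match the 32 input features of the first time-embedding linear layer in the model printout above.

```python
import math

def sinusoidal_embedding(t, dim=32, max_period=10000.0):
    """Map a scalar timestep to `dim` values using sines and cosines of different frequencies."""
    half = dim // 2
    # Frequencies fall geometrically from 1 towards 1 / max_period
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]

emb_10 = sinusoidal_embedding(10)
emb_500 = sinusoidal_embedding(500)
print(len(emb_10))
print([round(v, 3) for v in emb_10[:4]])
print([round(v, 3) for v in emb_500[:4]])
```

Each timestep gets a distinct, bounded vector that a small MLP can then project and add into the ResNet blocks, which is roughly the role of the `time_embedding` and `time_emb_proj` layers visible in the printout.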
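Sampling
Given a model that estimates the noise present in a noisy input (or predicts the denoised version), how do we produce new images?
We could feed in pure noise and hope that the model predicts a good image as the denoised version in one step. However, as we saw in the experiments above, this does not usually work well. So, instead, we take a number of smaller steps based on the model prediction, iteratively removing a little bit of the noise at a time.
Exactly how we take these steps depends on the sampling method used. We will not go into the theory too deeply, but some key design questions are:
- How large a step should you take? In other words, what "noise schedule" should you follow?
- Do you use only the model's current prediction to inform the update step (like DDPM, DDIM and many others)? Do you evaluate the model several times to estimate higher-order gradients for a larger, more accurate step (higher-order methods and some discrete ODE solvers)? Or do you keep a history of past predictions to better inform the current update step (linear multi-step methods)?
- Do you add in additional noise (sometimes called churn) to add more stochasticity (randomness) to the sampling process, as ancestral samplers do, or do you keep it completely deterministic? Many samplers control this with a parameter (such as ‘eta’ for DDIM samplers) so that the user can choose.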
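To tie these questions together, here is a scalar sketch of a single DDIM-style update, following the update rule from the DDIM paper as I understand it. It is not the diffusers implementation: `ddim_style_step` is a made-up name, it works on one number rather than an image tensor, and the $\bar{\alpha}$ values are arbitrary. The size of the step is set by the gap between `alpha_bar_t` and `alpha_bar_prev`, only the current noise prediction is used, and `eta` controls how much fresh noise is injected.

```python
import math
import random

def ddim_style_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev, eta=0.0, rng=random):
    """Move from noise level alpha_bar_t to the less noisy alpha_bar_prev (scalar sketch)."""
    # Clean sample implied by the current noise prediction
    x0_pred = (x_t - math.sqrt(1 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_bar_t)
    # eta = 0 gives a deterministic step; eta = 1 injects a DDPM-like amount of fresh noise
    sigma = (
        eta
        * math.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar_t))
        * math.sqrt(1 - alpha_bar_t / alpha_bar_prev)
    )
    # Part of the predicted noise is kept so the result has the right noise level
    direction = math.sqrt(max(1 - alpha_bar_prev - sigma**2, 0.0)) * eps_pred
    fresh_noise = sigma * rng.gauss(0.0, 1.0) if eta > 0 else 0.0
    return math.sqrt(alpha_bar_prev) * x0_pred + direction + fresh_noise

# Made-up example: a "pixel" x0 noised to alpha_bar = 0.5, then stepped back to alpha_bar = 0.8
x0, eps = 0.3, -1.2
a_t, a_prev = 0.5, 0.8
x_t = math.sqrt(a_t) * x0 + math.sqrt(1 - a_t) * eps

deterministic = ddim_style_step(x_t, eps, a_t, a_prev, eta=0.0)
expected = math.sqrt(a_prev) * x0 + math.sqrt(1 - a_prev) * eps
stochastic = ddim_style_step(x_t, eps, a_t, a_prev, eta=1.0, rng=random.Random(0))
print(deterministic, expected, stochastic)
```

With a perfect noise prediction and `eta=0`, the step lands exactly on the same sample noised to the lower level, which is why deterministic samplers can take fairly large steps. With `eta > 0` every run takes a slightly different path.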
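Research on sampling methods for diffusion models is evolving rapidly, and more and more methods for finding good solutions in fewer steps are being proposed. The brave and curious might find it interesting to browse the source code of the different implementations available in the diffusers library, or to check out the docs, which often link to the relevant papers.
Conclusions
Hopefully this has been a helpful way to look at diffusion models from a slightly different angle.
This notebook was written for the Hugging Face course by Jonathan Whitaker, and overlaps with a version included in his own course, ‘The Generative Landscape’. Check that out if you would like to see this basic example extended with noise and class conditioning. Questions or bugs can be communicated through GitHub issues or via Discord. You are also welcome to reach out via Twitter @johnowhitaker.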