从零开始的扩散模型
有时,考虑事物的最简单版本有助于更好地理解其工作原理。我们将在本笔记本中尝试这样做,首先从一个“玩具”扩散模型开始,了解不同部分是如何工作的,然后研究它们与更复杂实现的不同之处。
我们将着眼于
- 损坏过程(向数据添加噪声)
- 什么是 U-Net,以及如何从零开始实现一个极简的 U-Net
- 扩散模型训练
- 采样理论
然后,我们将我们的版本与 diffusers 的 DDPM 实现进行比较,探讨
- 我们迷你 U-Net 的改进
- DDPM 噪声调度
- 训练目标的差异
- 时间步长条件控制
- 采样方法
此笔记本内容相当深入,如果您对从头开始的深入研究不感兴趣,可以跳过!
还值得注意的是,此处的大多数代码仅用于说明目的,我不建议直接将任何代码用于您自己的工作(除非您只是为了学习目的尝试改进此处显示的示例)。
设置和导入:
>>> %pip install -q diffusers
[K |████████████████████████████████| 255 kB 16.0 MB/s [K |████████████████████████████████| 163 kB 53.9 MB/s [?25h
>>> import torch
>>> import torchvision
>>> from torch import nn
>>> from torch.nn import functional as F
>>> from torch.utils.data import DataLoader
>>> from diffusers import DDPMScheduler, UNet2DModel
>>> from matplotlib import pyplot as plt
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> print(f"Using device: {device}")
Using device: cuda
数据
在这里,我们将使用一个非常小的数据集:mnist 进行测试。如果您想让模型面临稍微更难的挑战,而无需更改任何其他内容,torchvision.datasets.FashionMNIST 应该可以作为直接替换使用。
>>> dataset = torchvision.datasets.MNIST(
... root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor()
... )
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist/MNIST/raw/train-images-idx3-ubyte.gz
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
>>> x, y = next(iter(train_dataloader))
>>> print("Input shape:", x.shape)
>>> print("Labels:", y)
>>> plt.imshow(torchvision.utils.make_grid(x)[0], cmap="Greys")
Input shape: torch.Size([8, 1, 28, 28]) Labels: tensor([1, 9, 7, 3, 5, 2, 1, 4])
每个图像都是一个灰度 28 像素 x 28 像素的数字图像,其值范围从 0 到 1。
损坏过程
假设您还没有阅读任何扩散模型论文,但您知道该过程涉及添加噪声。您将如何做到这一点?
我们可能希望有一种简单的方法来控制损坏量。那么,如果我们为要添加的噪声的 amount
量输入一个参数,然后我们执行
noise = torch.rand_like(x)
noisy_x = (1-amount)*x + amount*noise
如果 amount = 0,我们将获得没有任何更改的输入。如果 amount 达到 1,我们将获得没有输入 x 痕迹的噪声。通过以这种方式将输入与噪声混合,我们使输出保持在相同的范围内(0 到 1)。
我们可以很容易地实现这一点(只需注意形状,以免被广播规则烧伤)
def corrupt(x, amount):
"""Corrupt the input `x` by mixing it with noise according to `amount`"""
noise = torch.rand_like(x)
amount = amount.view(-1, 1, 1, 1) # Sort shape so broadcasting works
return x * (1 - amount) + noise * amount
并直观地查看结果,以确保其按预期工作
>>> # Plotting the input data
>>> fig, axs = plt.subplots(2, 1, figsize=(12, 5))
>>> axs[0].set_title("Input data")
>>> axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap="Greys")
>>> # Adding noise
>>> amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
>>> noised_x = corrupt(x, amount)
>>> # Plotting the noised version
>>> axs[1].set_title("Corrupted data (-- amount increases -->)")
>>> axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap="Greys")
随着噪声量接近 1,我们的数据开始看起来像纯粹的随机噪声。但是对于大多数噪声量,您可以很好地猜测数字。您认为这是最佳的吗?
模型
我们希望有一个模型,它可以接收 28 像素的噪声图像并输出相同形状的预测。这里一个流行的选择是一种称为 U-Net 的架构。 最初发明用于医学影像中的分割任务,U-Net 由一条“收缩路径”组成,数据通过该路径被压缩,以及一条“扩展路径”,数据通过该路径扩展回原始维度(类似于自动编码器),但也具有跳跃连接,允许信息和梯度在不同级别流动。
一些 U-Net 在每个阶段都具有复杂的块,但对于这个玩具演示,我们将构建一个最小的示例,它接收一个单通道图像并在向下路径上通过三个卷积层(图和代码中的 down_layers)和三个向上路径上的层,并在向下层和向上层之间具有跳跃连接。我们将使用最大池化进行下采样,并使用 nn.Upsample
进行上采样,而不是依赖于像更复杂的 U-Net 那样的可学习层。以下是显示每个层输出通道数的粗略架构
这就是它在代码中的样子
class BasicUNet(nn.Module):
"""A minimal UNet implementation."""
def __init__(self, in_channels=1, out_channels=1):
super().__init__()
self.down_layers = torch.nn.ModuleList(
[
nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
nn.Conv2d(32, 64, kernel_size=5, padding=2),
nn.Conv2d(64, 64, kernel_size=5, padding=2),
]
)
self.up_layers = torch.nn.ModuleList(
[
nn.Conv2d(64, 64, kernel_size=5, padding=2),
nn.Conv2d(64, 32, kernel_size=5, padding=2),
nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
]
)
self.act = nn.SiLU() # The activation function
self.downscale = nn.MaxPool2d(2)
self.upscale = nn.Upsample(scale_factor=2)
def forward(self, x):
h = []
for i, l in enumerate(self.down_layers):
x = self.act(l(x)) # Through the layer and the activation function
if i < 2: # For all but the third (final) down layer:
h.append(x) # Storing output for skip connection
x = self.downscale(x) # Downscale ready for the next layer
for i, l in enumerate(self.up_layers):
if i > 0: # For all except the first up layer
x = self.upscale(x) # Upscale
x += h.pop() # Fetching stored output (skip connection)
x = self.act(l(x)) # Through the layer and the activation function
return x
我们可以验证输出形状与输入相同,正如我们预期的那样
net = BasicUNet()
x = torch.rand(8, 1, 28, 28)
net(x).shape
这个网络有超过 300,000 个参数
sum([p.numel() for p in net.parameters()])
如果您愿意,可以探索更改每一层中的通道数量或换入不同的架构。
训练网络
那么模型到底应该做什么呢?同样地,对此有很多不同的理解,但对于这个演示,让我们选择一个简单的框架:给定一个损坏的输入 noisy_x,模型应该输出其对原始 x 的最佳猜测。我们将通过均方误差将其与实际值进行比较。
现在我们可以尝试训练网络了。
- 获取一批数据
- 通过随机量对其进行损坏
- 将其馈送到模型中
- 将模型预测与干净图像进行比较以计算损失
- 相应地更新模型的参数。
随意修改它,看看能否让它工作得更好!
>>> # Dataloader (you can mess with batch size)
>>> batch_size = 128
>>> train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
>>> # How many runs through the data should we do?
>>> n_epochs = 3
>>> # Create the network
>>> net = BasicUNet()
>>> net.to(device)
>>> # Our loss function
>>> loss_fn = nn.MSELoss()
>>> # The optimizer
>>> opt = torch.optim.Adam(net.parameters(), lr=1e-3)
>>> # Keeping a record of the losses for later viewing
>>> losses = []
>>> # The training loop
>>> for epoch in range(n_epochs):
... for x, y in train_dataloader:
... # Get some data and prepare the corrupted version
... x = x.to(device) # Data on the GPU
... noise_amount = torch.rand(x.shape[0]).to(device) # Pick random noise amounts
... noisy_x = corrupt(x, noise_amount) # Create our noisy x
... # Get the model prediction
... pred = net(noisy_x)
... # Calculate the loss
... loss = loss_fn(pred, x) # How close is the output to the true 'clean' x?
... # Backprop and update the params:
... opt.zero_grad()
... loss.backward()
... opt.step()
... # Store the loss for later
... losses.append(loss.item())
... # Print our the average of the loss values for this epoch:
... avg_loss = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
... print(f"Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}")
>>> # View the loss curve
>>> plt.plot(losses)
>>> plt.ylim(0, 0.1)
Finished epoch 0. Average loss for this epoch: 0.026736 Finished epoch 1. Average loss for this epoch: 0.020692 Finished epoch 2. Average loss for this epoch: 0.018887
我们可以尝试通过获取一批数据、以不同的量对其进行损坏,然后查看模型的预测来查看模型预测的样子
>>> # @markdown Visualizing model predictions on noisy inputs:
>>> # Fetch some data
>>> x, y = next(iter(train_dataloader))
>>> x = x[:8] # Only using the first 8 for easy plotting
>>> # Corrupt with a range of amounts
>>> amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
>>> noised_x = corrupt(x, amount)
>>> # Get the model predictions
>>> with torch.no_grad():
... preds = net(noised_x.to(device)).detach().cpu()
>>> # Plot
>>> fig, axs = plt.subplots(3, 1, figsize=(12, 7))
>>> axs[0].set_title("Input data")
>>> axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap="Greys")
>>> axs[1].set_title("Corrupted data")
>>> axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap="Greys")
>>> axs[2].set_title("Network Predictions")
>>> axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap="Greys")
您可以看到,对于较低的量,预测非常不错!但是随着级别的提高,模型可以利用的东西越来越少,当我们达到 amount=1 时,它会输出一个模糊的混乱,接近数据集的平均值,以试图对其输出可能是什么样子进行对冲……
采样
如果我们在高噪声水平下的预测不太好,我们如何生成图像呢?
好吧,如果我们从随机噪声开始,查看模型预测,但只朝该预测移动一小部分——比如 20% 的距离。现在我们有一个非常嘈杂的图像,其中可能包含结构的提示,我们可以将其馈送到模型以获得新的预测。希望这个新的预测比第一个预测稍微好一点(因为我们的起点稍微不那么嘈杂),所以我们可以用这个新的、更好的预测再迈出一小步。
重复几次,如果一切顺利,我们就会得到一张图像!以下是该过程在仅 5 个步骤中说明的情况,可视化每个阶段的模型输入(左侧)和预测的去噪图像(右侧)。请注意,即使模型在步骤 1 处预测去噪图像,我们也只向那里移动 x 部分的距离。经过几个步骤,结构出现并得到细化,直到我们获得最终输出。
>>> # @markdown Sampling strategy: Break the process into 5 steps and move 1/5'th of the way there each time:
>>> n_steps = 5
>>> x = torch.rand(8, 1, 28, 28).to(device) # Start from random
>>> step_history = [x.detach().cpu()]
>>> pred_output_history = []
>>> for i in range(n_steps):
... with torch.no_grad(): # No need to track gradients during inference
... pred = net(x) # Predict the denoised x0
... pred_output_history.append(pred.detach().cpu()) # Store model output for plotting
... mix_factor = 1 / (n_steps - i) # How much we move towards the prediction
... x = x * (1 - mix_factor) + pred * mix_factor # Move part of the way there
... step_history.append(x.detach().cpu()) # Store step for plotting
>>> fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)
>>> axs[0, 0].set_title("x (model input)")
>>> axs[0, 1].set_title("model prediction")
>>> for i in range(n_steps):
... axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap="Greys")
... axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0, 1), cmap="Greys")
我们可以将过程分成更多步骤,并希望以此获得更好的图像
>>> # @markdown Showing more results, using 40 sampling steps
>>> n_steps = 40
>>> x = torch.rand(64, 1, 28, 28).to(device)
>>> for i in range(n_steps):
... noise_amount = torch.ones((x.shape[0],)).to(device) * (1 - (i / n_steps)) # Starting high going low
... with torch.no_grad():
... pred = net(x)
... mix_factor = 1 / (n_steps - i)
... x = x * (1 - mix_factor) + pred * mix_factor
>>> fig, ax = plt.subplots(1, 1, figsize=(12, 12))
>>> ax.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap="Greys")
还不错,但那里有一些可识别的数字!您可以尝试更长时间的训练(例如,10 或 20 个 epochs)并调整模型配置、学习率、优化器等。此外,不要忘记,如果您想要一个稍微更难的数据集来尝试,fashionMNIST 可以一键替换。
与 DDPM 的比较
在本节中,我们将了解我们的玩具实现与另一个笔记本中使用的方法(Diffusers 简介)的不同之处,该方法基于 DDPM 论文。
我们将看到
- diffusers 的
UNet2DModel
比我们的 BasicUNet 更高级 - 损坏过程的处理方式不同
- 训练目标不同,涉及预测噪声而不是去噪图像
- 模型以时间步长调节为条件,其中 t 作为附加参数传递到前向方法。
- 有许多不同的采样策略可用,它们应该比我们上面简化的版本效果更好。
自从 DDPM 论文发表以来,已经提出了一些改进建议,但此示例希望能够说明不同的可用设计决策。阅读完本文后,您可能会喜欢深入研究论文 “阐明基于扩散的生成模型的设计空间”,该论文详细探讨了所有这些组件,并对如何获得最佳性能提出了新的建议。
如果所有这些都过于技术性或令人生畏,请不要担心!随意跳过本笔记本的其余部分或将其保存起来以备不时之需。
U-Net
diffusers 的 UNet2DModel 模型相较于我们上面基本的 U-Net 有许多改进
- GroupNorm 将组归一化应用于每个块的输入
- 用于更平滑训练的 Dropout 层
- 每个块的多个 ResNet 层(如果 layers_per_block 未设置为 1)
- 注意力(通常仅在较低分辨率块中使用)
- 以时间步长为条件。
- 具有可学习参数的下采样和上采样块
让我们创建一个并检查 UNet2DModel
>>> model = UNet2DModel(
... sample_size=28, # the target image resolution
... in_channels=1, # the number of input channels, 3 for RGB images
... out_channels=1, # the number of output channels
... layers_per_block=2, # how many ResNet layers to use per UNet block
... block_out_channels=(32, 64, 64), # Roughly matching our basic unet example
... down_block_types=(
... "DownBlock2D", # a regular ResNet downsampling block
... "AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
... "AttnDownBlock2D",
... ),
... up_block_types=(
... "AttnUpBlock2D",
... "AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
... "UpBlock2D", # a regular ResNet upsampling block
... ),
... )
>>> print(model)
UNet2DModel( (conv_in): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_proj): Timesteps() (time_embedding): TimestepEmbedding( (linear_1): Linear(in_features=32, out_features=128, bias=True) (act): SiLU() (linear_2): Linear(in_features=128, out_features=128, bias=True) ) (down_blocks): ModuleList( (0): DownBlock2D( (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 32, eps=1e-05, affine=True) (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=32, bias=True) (norm2): GroupNorm(32, 32, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 32, eps=1e-05, affine=True) (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=32, bias=True) (norm2): GroupNorm(32, 32, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) ) (downsamplers): ModuleList( (0): Downsample2D( (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) ) ) ) (1): AttnDownBlock2D( (attentions): ModuleList( (0): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) (1): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) ) (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 32, eps=1e-05, affine=True) (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1)) ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) ) (downsamplers): ModuleList( (0): Downsample2D( (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) ) ) ) (2): AttnDownBlock2D( (attentions): ModuleList( (0): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) (1): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) ) (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) ) ) ) (up_blocks): ModuleList( (0): AttnUpBlock2D( (attentions): ModuleList( (0): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) (1): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) (2): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) ) (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-05, affine=True) (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1)) ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-05, affine=True) (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1)) ) (2): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-05, affine=True) (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1)) ) ) (upsamplers): ModuleList( (0): Upsample2D( (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) ) (1): AttnUpBlock2D( (attentions): ModuleList( (0): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) (1): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) (2): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) ) (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-05, affine=True) (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1)) ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-05, affine=True) (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1)) ) (2): ResnetBlock2D( (norm1): GroupNorm(32, 96, eps=1e-05, affine=True) (conv1): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1)) ) ) (upsamplers): ModuleList( (0): Upsample2D( (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) ) (2): UpBlock2D( (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 96, eps=1e-05, affine=True) (conv1): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=32, bias=True) (norm2): GroupNorm(32, 32, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1)) ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=32, bias=True) (norm2): GroupNorm(32, 32, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1)) ) (2): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=32, bias=True) (norm2): GroupNorm(32, 32, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1)) ) ) ) ) (mid_block): UNetMidBlock2D( (attentions): ModuleList( (0): AttentionBlock( (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True) (query): Linear(in_features=64, out_features=64, bias=True) (key): Linear(in_features=64, out_features=64, bias=True) (value): Linear(in_features=64, out_features=64, bias=True) (proj_attn): Linear(in_features=64, out_features=64, bias=True) ) ) (resnets): ModuleList( (0): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) (1): ResnetBlock2D( (norm1): GroupNorm(32, 64, eps=1e-05, affine=True) (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (time_emb_proj): Linear(in_features=128, out_features=64, bias=True) (norm2): GroupNorm(32, 64, eps=1e-05, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() ) ) ) (conv_norm_out): GroupNorm(32, 32, eps=1e-05, affine=True) (conv_act): SiLU() (conv_out): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) )
如您所见,发生的事情更多了!它也比我们的 BasicUNet 拥有更多参数
sum([p.numel() for p in model.parameters()]) # 1.7M vs the ~309k parameters of the BasicUNet
我们可以使用此模型代替原始模型来复制上面显示的训练。我们需要将 x 和 timestep 都传递给模型(这里我始终传递 t=0 以表明它在没有此时间步长调节的情况下也能工作,并使采样代码保持简单,但您也可以尝试馈送 (amount*1000)
以从损坏量获得等效的时间步长)。如果要检查代码,则已更改的行将显示为 #<<<
。
>>> # @markdown Trying UNet2DModel instead of BasicUNet:
>>> # Dataloader (you can mess with batch size)
>>> batch_size = 128
>>> train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
>>> # How many runs through the data should we do?
>>> n_epochs = 3
>>> # Create the network
>>> net = UNet2DModel(
... sample_size=28, # the target image resolution
... in_channels=1, # the number of input channels, 3 for RGB images
... out_channels=1, # the number of output channels
... layers_per_block=2, # how many ResNet layers to use per UNet block
... block_out_channels=(32, 64, 64), # Roughly matching our basic unet example
... down_block_types=(
... "DownBlock2D", # a regular ResNet downsampling block
... "AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
... "AttnDownBlock2D",
... ),
... up_block_types=(
... "AttnUpBlock2D",
... "AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
... "UpBlock2D", # a regular ResNet upsampling block
... ),
... ) # <<<
>>> net.to(device)
>>> # Our loss finction
>>> loss_fn = nn.MSELoss()
>>> # The optimizer
>>> opt = torch.optim.Adam(net.parameters(), lr=1e-3)
>>> # Keeping a record of the losses for later viewing
>>> losses = []
>>> # The training loop
>>> for epoch in range(n_epochs):
... for x, y in train_dataloader:
... # Get some data and prepare the corrupted version
... x = x.to(device) # Data on the GPU
... noise_amount = torch.rand(x.shape[0]).to(device) # Pick random noise amounts
... noisy_x = corrupt(x, noise_amount) # Create our noisy x
... # Get the model prediction
... pred = net(noisy_x, 0).sample # <<< Using timestep 0 always, adding .sample
... # Calculate the loss
... loss = loss_fn(pred, x) # How close is the output to the true 'clean' x?
... # Backprop and update the params:
... opt.zero_grad()
... loss.backward()
... opt.step()
... # Store the loss for later
... losses.append(loss.item())
... # Print our the average of the loss values for this epoch:
... avg_loss = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
... print(f"Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}")
>>> # Plot losses and some samples
>>> fig, axs = plt.subplots(1, 2, figsize=(12, 5))
>>> # Losses
>>> axs[0].plot(losses)
>>> axs[0].set_ylim(0, 0.1)
>>> axs[0].set_title("Loss over time")
>>> # Samples
>>> n_steps = 40
>>> x = torch.rand(64, 1, 28, 28).to(device)
>>> for i in range(n_steps):
... noise_amount = torch.ones((x.shape[0],)).to(device) * (1 - (i / n_steps)) # Starting high going low
... with torch.no_grad():
... pred = net(x, 0).sample
... mix_factor = 1 / (n_steps - i)
... x = x * (1 - mix_factor) + pred * mix_factor
>>> axs[1].imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap="Greys")
>>> axs[1].set_title("Generated Samples")
Finished epoch 0. Average loss for this epoch: 0.018925 Finished epoch 1. Average loss for this epoch: 0.012785 Finished epoch 2. Average loss for this epoch: 0.011694
这看起来比我们的第一组结果好多了!您可以尝试调整 U-Net 配置或更长时间地训练以获得更好的性能。
损坏过程
DDPM 论文描述了一个损坏过程,该过程为每个“时间步长”添加少量噪声。给定某个时间步长的 $x_{t-1}$,我们可以通过以下方式获得下一个(稍微更嘈杂)版本 $x_t$:
$q(\mathbf{x}t \vert \mathbf{x}{t-1}) = \mathcal{N}(\mathbf{x}t; \sqrt{1 - \beta_t} \mathbf{x}{t-1}, \betat\mathbf{I}) \quad q(\mathbf{x}{1:T} \vert \mathbf{x}0) = \prod^T{t=1} q(\mathbf{x}t \vert \mathbf{x}{t-1})$
也就是说,我们取 $x{t-1}$,将其按 $\sqrt{1 - \beta_t}$ 的比例缩放,并添加按 $\beta_t$ 的比例缩放的噪声。这个 $\beta$ 是根据某个时间表为每个 t 定义的,并确定每个时间步长添加多少噪声。现在,我们不一定希望执行此操作 500 次以获得 $x{500}$,因此我们有另一个公式来根据 $x_0$ 获取任何 t 的 $x_t$
$\begin{aligned} q(\mathbf{x}t \vert \mathbf{x}_0) &= \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, \sqrt{(1 - \bar{\alpha}_t)} \mathbf{I}) \end{aligned}$ 其中 $\bar{\alpha}_t = \prod{i=1}^T \alpha_i$ 且 $\alpha_i = 1-\beta_i$
数学符号总是看起来很吓人!幸运的是,调度程序为我们处理了所有这些(取消注释下一个单元格以查看代码)。我们可以绘制 $\sqrt{\bar{\alpha}_t}$(标记为 sqrt_alpha_prod
)和 $\sqrt{(1 - \bar{\alpha}_t)}$(标记为 sqrt_one_minus_alpha_prod
)以查看输入 (x) 和噪声如何在不同时间步长中进行缩放和混合
# ??noise_scheduler.add_noise
>>> noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
>>> plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${\sqrt{\bar{\alpha}_t}}$")
>>> plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5, label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")
>>> plt.legend(fontsize="x-large")
最初,嘈杂的 x 主要为 x(sqrt_alpha_prod ~= 1),但随着时间的推移,x 的贡献下降,而噪声成分增加。与根据 amount
对 x 和噪声进行线性混合不同,这种噪声相对较快。我们可以在一些数据上将其可视化
>>> # @markdown visualize the DDPM noising process for different timesteps:
>>> # Noise a batch of images to view the effect
>>> fig, axs = plt.subplots(3, 1, figsize=(16, 10))
>>> xb, yb = next(iter(train_dataloader))
>>> xb = xb.to(device)[:8]
>>> xb = xb * 2.0 - 1.0 # Map to (-1, 1)
>>> print("X shape", xb.shape)
>>> # Show clean inputs
>>> axs[0].imshow(torchvision.utils.make_grid(xb[:8])[0].detach().cpu(), cmap="Greys")
>>> axs[0].set_title("Clean X")
>>> # Add noise with scheduler
>>> timesteps = torch.linspace(0, 999, 8).long().to(device)
>>> noise = torch.randn_like(xb) # << NB: randn not rand
>>> noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
>>> print("Noisy X shape", noisy_xb.shape)
>>> # Show noisy version (with and without clipping)
>>> axs[1].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu().clip(-1, 1), cmap="Greys")
>>> axs[1].set_title("Noisy X (clipped to (-1, 1)")
>>> axs[2].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu(), cmap="Greys")
>>> axs[2].set_title("Noisy X")
X shape torch.Size([8, 1, 28, 28]) Noisy X shape torch.Size([8, 1, 28, 28])
另一个起作用的动态:DDPM 版本添加了从高斯分布中提取的噪声(均值为 0,标准差为 1 来自 torch.randn),而不是我们最初的 corrupt
函数中使用的 0 到 1 之间的均匀噪声(来自 torch.rand)。通常,对训练数据进行归一化也是有意义的。在另一个笔记本中,您将在转换列表中看到 Normalize(0.5, 0.5)
,它将图像数据从 (0, 1) 映射到 (-1, 1),并且对于我们的目的来说“足够好”。我们没有在本笔记本中这样做,但上面的可视化单元格将其添加进来以实现更准确的缩放和可视化。
训练目标
在我们的玩具示例中,我们让模型尝试预测去噪图像。在 DDPM 和许多其他扩散模型实现中,模型预测损坏过程中使用的噪声(缩放前,因此为单位方差噪声)。在代码中,它看起来像这样
noise = torch.randn_like(xb) # << NB: randn not rand
noisy_x = noise_scheduler.add_noise(x, noise, timesteps)
model_prediction = model(noisy_x, timesteps).sample
loss = mse_loss(model_prediction, noise) # noise as the target
你可能认为预测噪声(从中我们可以推导出去噪后的图像是什么样子)等同于直接预测去噪后的图像。那么,为什么偏爱其中一个而不是另一个——仅仅是为了数学上的方便吗?
事实证明,这里还有一个微妙之处。我们在训练期间跨越不同的(随机选择的)时间步计算损失。这些不同的目标将导致这些损失的不同“隐式加权”,其中预测噪声会更多地关注较低的噪声水平。你可以选择更复杂的目标来改变这种“隐式损失加权”。或者,你可能选择一个噪声调度,它将在更高的噪声水平下产生更多示例。也许你让模型预测一个“速度”v,我们将其定义为图像和噪声的组合,具体取决于噪声水平(参见“用于快速采样扩散模型的渐进蒸馏”)。也许你让模型预测噪声,然后根据一些理论(参见“扩散模型的感知优先训练”)或根据尝试查看哪些噪声水平对模型最有信息量的实验(参见“阐明基于扩散的生成模型的设计空间”)将损失按某个取决于噪声量的因子进行缩放。TL;DR:选择目标会影响模型性能,并且正在进行关于“最佳”选项的研究。
目前,预测噪声(在某些地方你会看到 epsilon 或 eps)是首选方法,但随着时间的推移,我们可能会看到库中支持其他目标并在不同情况下使用。
时间步条件
UNet2DModel 接收 x 和时间步作为输入。后者被转换为嵌入并馈送到模型的多个位置。
背后的理论是,通过向模型提供有关噪声水平的信息,它可以更好地执行其任务。虽然可以训练没有这种时间步条件的模型,但在某些情况下它确实有助于提高性能,并且大多数实现都包含它,至少在当前文献中是这样。
采样
给定一个估计噪声输入中存在噪声(或预测去噪版本)的模型,我们如何生成新的图像?
我们可以输入纯噪声,并希望模型在一步骤中预测一个好的图像作为去噪版本。但是,正如我们在上面的实验中看到的,这通常效果不好。因此,我们根据模型预测采取若干个较小的步骤,每次迭代地去除少量噪声。
我们如何采取这些步骤完全取决于所使用的采样方法。我们不会深入研究理论,但一些关键的设计问题是
- 你应该采取多大的步骤?换句话说,你应该遵循什么“噪声调度”?
- 你是否只使用模型的当前预测来告知更新步骤(如 DDPM、DDIM 和许多其他方法)?你是否评估模型几次以估计更高阶梯度以获得更大、更准确的步骤(高阶方法和一些离散 ODE 求解器)?或者你是否保留过去预测的历史记录以尝试更好地告知当前更新步骤(线性多步和祖先采样器)?
- 你是否添加了额外的噪声(有时称为 churn)以向采样过程添加更多随机性,或者你是否将其保持完全确定性?许多采样器使用参数(例如 DDIM 采样器的“eta”)来控制这一点,以便用户可以选择。
扩散模型的采样方法的研究正在迅速发展,并且提出了越来越多的在更少步骤中找到良好解决方案的方法。勇敢而好奇的人可能会发现浏览 diffusers 库中提供的不同实现的代码很有趣这里或查看文档,这些文档通常会链接到相关的论文。
结论
希望这能成为一种从稍微不同的角度看待扩散模型的有用方法。
此笔记本由 Jonathan Whitaker 为此 Hugging Face 课程编写,与他自己的课程中包含的版本“生成景观”重叠。如果你想看到此基本示例扩展到噪声和类别条件,请查看该课程。可以通过 GitHub 问题或 Discord 传达问题或错误。你也可以通过 Twitter@johnowhitaker 联系我们。