
Making a Class-Conditioned Diffusion Model

In this notebook we're going to illustrate one way to add conditioning information to a diffusion model. Specifically, following on from the "from-scratch" example in Unit 1, we'll train a class-conditioned diffusion model on MNIST, where we can specify which digit we'd like the model to generate at inference time.

As mentioned in the introduction to this unit, this is just one of many ways we could add additional conditioning information to a diffusion model, and has been chosen for its relative simplicity. Just like the "from-scratch" notebook in Unit 1, this notebook is mostly for illustrative purposes, and you can safely skip it if you'd like.

Setup and Data Prep

>>> %pip install -q diffusers
>>> import torch
>>> import torchvision
>>> from torch import nn
>>> from torch.nn import functional as F
>>> from torch.utils.data import DataLoader
>>> from diffusers import DDPMScheduler, UNet2DModel
>>> from matplotlib import pyplot as plt
>>> from tqdm.auto import tqdm

>>> device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
>>> print(f"Using device: {device}")
Using device: cuda
>>> # Load the dataset
>>> dataset = torchvision.datasets.MNIST(
...     root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor()
... )

>>> # Feed it into a dataloader (batch size 8 here just for demo)
>>> train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

>>> # View some examples
>>> x, y = next(iter(train_dataloader))
>>> print("Input shape:", x.shape)
>>> print("Labels:", y)
>>> plt.imshow(torchvision.utils.make_grid(x)[0], cmap="Greys")
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist/MNIST/raw/train-images-idx3-ubyte.gz

Creating a Class-Conditioned UNet

The way we'll feed in the class conditioning is as follows:

  • Create a standard UNet2DModel with some additional input channels
  • Map the class label to a learned vector of shape (class_emb_size) via an embedding layer
  • Concatenate this information as extra channels onto the internal input of the UNet with net_input = torch.cat((x, class_cond), 1)
  • Feed this net_input (with (class_emb_size + 1) channels in total) through the UNet to get the final prediction

In this example I've set class_emb_size to 4, but this is completely arbitrary, and you could explore setting it to 1 (to see if it still works), to 10 (to match the number of classes), or replacing the learned nn.Embedding with a simple one-hot encoding of the class label directly (see the sketch below).
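To make that last variation concrete, here is a minimal sketch (my addition, not part of the original notebook) of what one-hot conditioning could look like in place of the learned embedding; with num_classes=10 the conditioning would then occupy 10 extra channels:

# Sketch: one-hot class conditioning instead of a learned nn.Embedding
# (illustrative only; assumes the same 28x28 single-channel MNIST setup as below)
class_labels = torch.tensor([3, 7])  # (bs,) example labels
class_cond = F.one_hot(class_labels, num_classes=10).float()  # (bs, 10)
class_cond = class_cond.view(-1, 10, 1, 1).expand(-1, 10, 28, 28)  # (bs, 10, 28, 28)
# The UNet would then be created with in_channels = 1 + 10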

Here's what this looks like in code:

class ClassConditionedUnet(nn.Module):
    def __init__(self, num_classes=10, class_emb_size=4):
        super().__init__()

        # The embedding layer will map the class label to a vector of size class_emb_size
        self.class_emb = nn.Embedding(num_classes, class_emb_size)

        # self.model is an unconditional UNet with extra input channels to accept the conditioning information (the class embedding)
        self.model = UNet2DModel(
            sample_size=28,  # the target image resolution
            in_channels=1 + class_emb_size,  # Additional input channels for class cond.
            out_channels=1,  # the number of output channels
            layers_per_block=2,  # how many ResNet layers to use per UNet block
            block_out_channels=(32, 64, 64),
            down_block_types=(
                "DownBlock2D",  # a regular ResNet downsampling block
                "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
                "AttnDownBlock2D",
            ),
            up_block_types=(
                "AttnUpBlock2D",
                "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
                "UpBlock2D",  # a regular ResNet upsampling block
            ),
        )

    # Our forward method now takes the class labels as an additional argument
    def forward(self, x, t, class_labels):
        # Shape of x: (bs, 1, 28, 28)
        bs, ch, w, h = x.shape

        # class conditioning in right shape to add as additional input channels
        class_cond = self.class_emb(class_labels)  # Map to embedding dimension
        class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
        # x is shape (bs, 1, 28, 28) and class_cond is now (bs, 4, 28, 28)

        # Net input is now x and class cond concatenated together along dimension 1
        net_input = torch.cat((x, class_cond), 1)  # (bs, 5, 28, 28)

        # Feed this to the UNet alongside the timestep and return the prediction
        return self.model(net_input, t).sample  # (bs, 1, 28, 28)

If any of the shapes or transforms are confusing, add print statements to show the relevant shapes and check that they match your expectations. I've also annotated the shapes of some intermediate variables to hopefully make things clearer.
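For example, a quick smoke test like this one (my addition, not in the original notebook) runs a dummy batch through the model and checks that the output shape matches the input image shape:

# Sketch: sanity-check shapes with a dummy batch before training
net = ClassConditionedUnet()
dummy_x = torch.randn(8, 1, 28, 28)  # a fake batch of 8 "images"
dummy_t = torch.randint(0, 1000, (8,))  # random timesteps
dummy_y = torch.randint(0, 10, (8,))  # random class labels
print(net(dummy_x, dummy_t, dummy_y).shape)  # expect torch.Size([8, 1, 28, 28])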

Training and Sampling

Where previously we'd do something like prediction = unet(x, t), we'll now add the correct labels as a third argument during training (prediction = unet(x, t, y)), and at inference time we can pass in whatever labels we want and, if all goes well, the model should generate images that match. y in this case is the label of the MNIST digit, with values from 0 to 9.

The training loop is very similar to the example in Unit 1. We are now predicting the noise (rather than the denoised image as in Unit 1), to match the objective expected by the default DDPMScheduler, which we use to add noise during training and to generate samples at inference time. Training takes a while; speeding it up could be a fun mini-project, but most of you can probably just skim the code (and indeed the whole notebook) without running it, since we're just illustrating an idea.

# Create a scheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")
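Before relying on add_noise in the loop below, it can be reassuring to see it in action. This small sketch (my addition, not in the original notebook) corrupts the demo batch from earlier at a few increasing timesteps; the digits should get progressively harder to recognise:

# Sketch: visualise noise_scheduler.add_noise at a few timesteps
xb = x[:8].to(device) * 2 - 1  # reuse the demo batch, mapped to (-1, 1)
fig, axs = plt.subplots(1, 5, figsize=(15, 3))
for ax, t in zip(axs, [0, 250, 500, 750, 999]):
    timesteps = torch.full((xb.shape[0],), t, device=device).long()
    noisy_xb = noise_scheduler.add_noise(xb, torch.randn_like(xb), timesteps)
    ax.imshow(torchvision.utils.make_grid(noisy_xb.cpu().clip(-1, 1), nrow=4)[0], cmap="Greys")
    ax.set_title(f"t={t}")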
>>> # @markdown Training loop (10 Epochs):

>>> # Redefining the dataloader to set the batch size higher than the demo of 8
>>> train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

>>> # How many runs through the data should we do?
>>> n_epochs = 10

>>> # Our network
>>> net = ClassConditionedUnet().to(device)

>>> # Our loss function
>>> loss_fn = nn.MSELoss()

>>> # The optimizer
>>> opt = torch.optim.Adam(net.parameters(), lr=1e-3)

>>> # Keeping a record of the losses for later viewing
>>> losses = []

>>> # The training loop
>>> for epoch in range(n_epochs):
...     for x, y in tqdm(train_dataloader):

...         # Get some data and prepare the corrupted version
...         x = x.to(device) * 2 - 1  # Data on the GPU (mapped to (-1, 1))
...         y = y.to(device)
...         noise = torch.randn_like(x)
...         timesteps = torch.randint(0, 1000, (x.shape[0],)).long().to(device)  # random timestep in [0, 999]
...         noisy_x = noise_scheduler.add_noise(x, noise, timesteps)

...         # Get the model prediction
...         pred = net(noisy_x, timesteps, y)  # Note that we pass in the labels y

...         # Calculate the loss
...         loss = loss_fn(pred, noise)  # How close is the output to the noise

...         # Backprop and update the params:
...         opt.zero_grad()
...         loss.backward()
...         opt.step()

...         # Store the loss for later
...         losses.append(loss.item())

...     # Print out the average of the last 100 loss values to get an idea of progress:
...     avg_loss = sum(losses[-100:]) / 100
...     print(f"Finished epoch {epoch}. Average of the last 100 loss values: {avg_loss:05f}")

>>> # View the loss curve
>>> plt.plot(losses)
Finished epoch 0. Average of the last 100 loss values: 0.052451

With training complete, we can sample some images, feeding in different labels as our conditioning:

>>> # @markdown Sampling some different digits:

>>> # Prepare random x to start from, plus some desired labels y
>>> x = torch.randn(80, 1, 28, 28).to(device)
>>> y = torch.tensor([[i] * 8 for i in range(10)]).flatten().to(device)

>>> # Sampling loop
>>> for i, t in tqdm(enumerate(noise_scheduler.timesteps)):

...     # Get model pred
...     with torch.no_grad():
...         residual = net(x, t, y)  # Again, note that we pass in our labels y

...     # Update sample with step
...     x = noise_scheduler.step(residual, t, x).prev_sample

>>> # Show the results
>>> fig, ax = plt.subplots(1, 1, figsize=(12, 12))
>>> ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], cmap="Greys")

That's it! We now have some control over what images are produced.
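For instance, to produce a batch of just one chosen digit, the same sampling loop works with a constant label (a small variation on the cell above, my addition, using an arbitrary choice of the digit 3):

# Sketch: sample 16 examples of a single chosen digit (here, 3)
x = torch.randn(16, 1, 28, 28).to(device)
y = torch.full((16,), 3, device=device).long()
for t in noise_scheduler.timesteps:
    with torch.no_grad():
        residual = net(x, t, y)
    x = noise_scheduler.step(residual, t, x).prev_sample
plt.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=4)[0], cmap="Greys")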

Hopefully you've enjoyed this example. As always, feel free to ask questions on Discord.

# Exercise (optional): Try this with FashionMNIST. Tweak the learning rate, batch size and number of epochs.
# Can you get some decent-looking fashion images with less training time than the example above?
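If you give the exercise a go, one possible starting point (my sketch; the hyperparameters are yours to tune) is simply swapping in the FashionMNIST dataset, which shares MNIST's 28x28 single-channel format:

# Sketch: swap in FashionMNIST for the optional exercise (same shapes as MNIST)
dataset = torchvision.datasets.FashionMNIST(
    root="fashion_mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor()
)
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
# Then re-run the training loop above, adjusting lr, batch size and n_epochs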