从卷积的角度理解扩散原理

社区文章 发布于2025年1月16日

1. 什么是卷积?

1.1 从数学角度理解卷积

卷积的数学公式通常表示为两个函数 ( f(x) ) 和 ( g(x) ) 的卷积。它定义为:

(fg)(x)=f(t)g(xt)dt (f * g)(x) = \int_{-\infty}^{\infty} f(t) \cdot g(x - t) \, dt

其中:

  • ( f(x) ) 和 ( g(x) ) 是要进行卷积的两个函数。
  • ( (f * g)(x) ) 是卷积后的结果函数。
  • ( t ) 是积分变量。

对于离散卷积,公式为:

(fg)[n]=k=f[k]g[nk] (f * g)[n] = \sum_{k=-\infty}^{\infty} f[k] \cdot g[n - k]

这里,( f[k] ) 和 ( g[k] ) 是离散信号,( n ) 是离散输出索引。

1.2 卷积的可视化

在图像中,左侧显示了灰度图像的像素值矩阵(即图像如何以数字形式呈现给计算机)。中间是**卷积核**矩阵,它从左上角开始在原始图像上滑动。卷积核在每个位置计算一个值,并在图像中重复此过程。得到的值构成**右侧图像(特征图)**,其中包含通过卷积过程获得的原始图像的局部特征。

image/png

动画是这样运作的:

image/gif

2. 扩散模型原理

2.1 早期生成模型的原理

早期的生成模型,如GAN(生成对抗网络)和VAE(变分自编码器),涉及到原始模型的反演。例如,在GAN中,识别模型是用于识别生成图像的传统卷积网络。然而,生成模型通过使用转置卷积网络(也称为反卷积)来生成图像,从而反转了这一过程,但这种方法未能产生理想的结果。

我们来谈谈转置卷积。转置卷积是卷积的逆操作:卷积将大矩阵变为小矩阵,而转置卷积则将小矩阵生成为大矩阵。如下图所示,它创建了虚线矩阵!

image/gif

2.2 扩散模型

直接生成图像并不理想,因此科学家们从物理学中的扩散现象中获得了灵感。在自然界中,物质倾向于向无序状态扩散。例如,当一滴墨水滴入一杯水中时,它会逐渐扩散开来。这表明生成模型也可以采取渐进、循序渐进的方法,而不是急于求成,以期取得稳定进展。

因此,扩散模型应运而生。我们首先向图像像素中添加噪声,从而得到一个非常混乱的图像。反之,我们也可以反转这个过程,从这个嘈杂的图像中恢复原始图像。

image/png

2.3 扩散模型中的卷积

扩散模型通常使用UNet网络来预测去噪图像,并添加**时间步长**以反映噪声水平。预测是针对图像的每个**时间步长**进行的。

如图所示,这是扩散模型中使用的UNet网络中的一个卷积核(稍后将提供代码实现)。实际上,在整个网络中,卷积核的属性基本保持不变,并且在正向传播过程中输入的宽度和高度不会改变。只有通道数会改变。

image/png

我们记得,**卷积**将矩阵映射到特征矩阵,而**扩散**将无序引入矩阵。可以这样理解:**卷积**扰乱或恢复矩阵的局部特征,而**扩散**则依赖**卷积**来扩散局部特征。

image/png

3. 扩散模型代码实现

理论是一回事,让我们来看一个实际的例子。

3.1 导入所需库

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

3.2 使用MNIST数据集

dataset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=torchvision.transforms.ToTensor()
)
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

3.3 编写噪声破坏公式

破坏意味着将图像与噪声按一定比例混合以实现去噪。随着扩散过程的进行,图像变得更清晰,噪声的影响也更小。

def corrupt(x, amount):
    """Corrupt the input `x` by mixing it with noise according to `amount`"""
    noise = torch.rand_like(x)
    amount = amount.view(-1, 1, 1, 1)  # Adjust shape for broadcasting
    return x * (1 - amount) + noise * amount

3.4 创建一个简单的UNet模型

我们将使用一个迷你UNet模型(不是标准模型),它仍然能取得不错的R结果。

class BasicUNet(nn.Module):
    """A minimal UNet implementation."""

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down_layers = torch.nn.ModuleList(
            [
                nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
                nn.Conv2d(32, 64, kernel_size=5, padding=2),
                nn.Conv2d(64, 64, kernel_size=5, padding=2),
            ]
        )
        self.up_layers = torch.nn.ModuleList(
            [
                nn.Conv2d(64, 64, kernel_size=5, padding=2),
                nn.Conv2d(64, 32, kernel_size=5, padding=2),
                nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
            ]
        )
        self.act = nn.SiLU()  # Activation function
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        h = []
        for i, l in enumerate(self.down_layers):
            x = self.act(l(x))  # Pass through the layer and activation function
            if i < 2:  # Skip connection for all but the final down layer
                h.append(x)  # Store output for skip connection
                x = self.downscale(x)  # Downscale for next layer

        for i, l in enumerate(self.up_layers):
            if i > 0:  # For all but the first up layer
                x = self.upscale(x)  # Upscale
                x += h.pop()  # Use stored output (skip connection)
            x = self.act(l(x))  # Pass through the layer and activation function

        return x

网络结构如下:

image/png

3.5 定义训练参数

# Dataloader (adjust batch size)
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Number of epochs
n_epochs = 30

# Create the network
net = UNet2DModel(
    sample_size=28,  # Target image resolution
    in_channels=1,  # Input channels, 3 for RGB images
    out_channels=1,  # Output channels
    layers_per_block=2,  # ResNet layers per UNet block
    block_out_channels=(32, 64, 64),  # Matching our basic UNet example
    down_block_types=(
        "DownBlock2D",  # Regular ResNet downsampling block
        "AttnDownBlock2D",  # ResNet block with attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",  # ResNet block with attention
        "UpBlock2D",  # Regular ResNet upsampling block
    ),
)
net.to(device)

# Loss function
loss_fn = nn.MSELoss()

# Optimizer
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

# Track losses
losses = []

# Training loop
for epoch in range(n_epochs):
    for x, y in train_dataloader:
        x = x.to(device)  # Data on GPU
        noise_amount = torch.rand(x.shape[0]).to(device)  # Random noise amount
        noisy_x = corrupt(x, noise_amount)  # Create noisy input

        # Get model prediction
        pred = net(noisy_x, 0).sample  # Use timestep 0

        # Calculate the loss
        loss = loss_fn(pred, x)  # Compare to original clean image

        # Backprop and update parameters
        opt.zero_grad()
        loss.backward()
        opt.step()

        losses.append(loss.item())

    avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
    print(f"Epoch {epoch}. Average loss: {avg_loss:.5f}")

3.6 测试模型

训练完成后,我们可以测试扩散模型的强大能力。

n_steps = 40
x = torch.rand(8, 1, 28, 28).to(device)  # Start from random noise
step_history = [x.detach().cpu()]
pred_output_history = []

for i in range(n_steps):
    with torch.no_grad():  # No gradients during inference
        pred = net(x,0).sample  # Predict denoised image
    pred_output_history.append(pred.detach().cpu())  # Store prediction
    mix_factor = 1 / (n_steps - i)  # Mix towards prediction
    x = x * (1 - mix_factor) + pred * mix_factor  # Move partway to the denoised image
    step_history.append(x.detach().cpu())  # Store for plotting

fig, axs = plt.subplots(n_steps, 2, figsize=(9, 32), sharex=True)
axs[0, 0].set_title("Input (noisy)")
axs[0, 1].set_title("Model Prediction")
for i in range(n_steps):
    axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap="Greys")
    axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0, 1), cmap="Greys")

image/png

这展示了模型测试的最终结果,输出看起来相当令人印象深刻!

总结

至此,扩散模型的解释就结束了。你可以尝试修改UNet网络或调整参数,看看能否获得更出色的结果!

如果你觉得这篇博文有帮助,请考虑点赞 🤗。

社区

注册登录 评论