从卷积的角度理解扩散原理
1. 什么是卷积?
1.1 从数学角度理解卷积
卷积的数学公式通常表示为两个函数 ( f(x) ) 和 ( g(x) ) 的卷积。它定义为:
其中:
- ( f(x) ) 和 ( g(x) ) 是要进行卷积的两个函数。
- ( (f * g)(x) ) 是卷积后的结果函数。
- ( t ) 是积分变量。
对于离散卷积,公式为:
这里,( f[k] ) 和 ( g[k] ) 是离散信号,( n ) 是离散输出索引。
1.2 卷积的可视化
在图像中,左侧显示了灰度图像的像素值矩阵(即图像如何以数字形式呈现给计算机)。中间是**卷积核**矩阵,它从左上角开始在原始图像上滑动。卷积核在每个位置计算一个值,并在图像中重复此过程。得到的值构成**右侧图像(特征图)**,其中包含通过卷积过程获得的原始图像的局部特征。
动画是这样运作的:
2. 扩散模型原理
2.1 早期生成模型的原理
早期的生成模型,如GAN(生成对抗网络)和VAE(变分自编码器),涉及到原始模型的反演。例如,在GAN中,识别模型是用于识别生成图像的传统卷积网络。然而,生成模型通过使用转置卷积网络(也称为反卷积)来生成图像,从而反转了这一过程,但这种方法未能产生理想的结果。
我们来谈谈转置卷积。转置卷积是卷积的逆操作:卷积将大矩阵变为小矩阵,而转置卷积则将小矩阵生成为大矩阵。如下图所示,它创建了虚线矩阵!
2.2 扩散模型
直接生成图像并不理想,因此科学家们从物理学中的扩散现象中获得了灵感。在自然界中,物质倾向于向无序状态扩散。例如,当一滴墨水滴入一杯水中时,它会逐渐扩散开来。这表明生成模型也可以采取渐进、循序渐进的方法,而不是急于求成,以期取得稳定进展。
因此,扩散模型应运而生。我们首先向图像像素中添加噪声,从而得到一个非常混乱的图像。反之,我们也可以反转这个过程,从这个嘈杂的图像中恢复原始图像。
2.3 扩散模型中的卷积
扩散模型通常使用UNet网络来预测去噪图像,并添加**时间步长**以反映噪声水平。预测是针对图像的每个**时间步长**进行的。
如图所示,这是扩散模型中使用的UNet网络中的一个卷积核(稍后将提供代码实现)。实际上,在整个网络中,卷积核的属性基本保持不变,并且在正向传播过程中输入的宽度和高度不会改变。只有通道数会改变。
我们记得,**卷积**将矩阵映射到特征矩阵,而**扩散**将无序引入矩阵。可以这样理解:**卷积**扰乱或恢复矩阵的局部特征,而**扩散**则依赖**卷积**来扩散局部特征。
3. 扩散模型代码实现
理论是一回事,让我们来看一个实际的例子。
3.1 导入所需库
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
import os
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
3.2 使用MNIST数据集
dataset = torchvision.datasets.MNIST(
root="./data", train=True, download=True, transform=torchvision.transforms.ToTensor()
)
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
3.3 编写噪声破坏公式
破坏意味着将图像与噪声按一定比例混合以实现去噪。随着扩散过程的进行,图像变得更清晰,噪声的影响也更小。
def corrupt(x, amount):
"""Corrupt the input `x` by mixing it with noise according to `amount`"""
noise = torch.rand_like(x)
amount = amount.view(-1, 1, 1, 1) # Adjust shape for broadcasting
return x * (1 - amount) + noise * amount
3.4 创建一个简单的UNet模型
我们将使用一个迷你UNet模型(不是标准模型),它仍然能取得不错的R结果。
class BasicUNet(nn.Module):
"""A minimal UNet implementation."""
def __init__(self, in_channels=1, out_channels=1):
super().__init__()
self.down_layers = torch.nn.ModuleList(
[
nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
nn.Conv2d(32, 64, kernel_size=5, padding=2),
nn.Conv2d(64, 64, kernel_size=5, padding=2),
]
)
self.up_layers = torch.nn.ModuleList(
[
nn.Conv2d(64, 64, kernel_size=5, padding=2),
nn.Conv2d(64, 32, kernel_size=5, padding=2),
nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
]
)
self.act = nn.SiLU() # Activation function
self.downscale = nn.MaxPool2d(2)
self.upscale = nn.Upsample(scale_factor=2)
def forward(self, x):
h = []
for i, l in enumerate(self.down_layers):
x = self.act(l(x)) # Pass through the layer and activation function
if i < 2: # Skip connection for all but the final down layer
h.append(x) # Store output for skip connection
x = self.downscale(x) # Downscale for next layer
for i, l in enumerate(self.up_layers):
if i > 0: # For all but the first up layer
x = self.upscale(x) # Upscale
x += h.pop() # Use stored output (skip connection)
x = self.act(l(x)) # Pass through the layer and activation function
return x
网络结构如下:
3.5 定义训练参数
# Dataloader (adjust batch size)
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Number of epochs
n_epochs = 30
# Create the network
net = UNet2DModel(
sample_size=28, # Target image resolution
in_channels=1, # Input channels, 3 for RGB images
out_channels=1, # Output channels
layers_per_block=2, # ResNet layers per UNet block
block_out_channels=(32, 64, 64), # Matching our basic UNet example
down_block_types=(
"DownBlock2D", # Regular ResNet downsampling block
"AttnDownBlock2D", # ResNet block with attention
"AttnDownBlock2D",
),
up_block_types=(
"AttnUpBlock2D",
"AttnUpBlock2D", # ResNet block with attention
"UpBlock2D", # Regular ResNet upsampling block
),
)
net.to(device)
# Loss function
loss_fn = nn.MSELoss()
# Optimizer
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
# Track losses
losses = []
# Training loop
for epoch in range(n_epochs):
for x, y in train_dataloader:
x = x.to(device) # Data on GPU
noise_amount = torch.rand(x.shape[0]).to(device) # Random noise amount
noisy_x = corrupt(x, noise_amount) # Create noisy input
# Get model prediction
pred = net(noisy_x, 0).sample # Use timestep 0
# Calculate the loss
loss = loss_fn(pred, x) # Compare to original clean image
# Backprop and update parameters
opt.zero_grad()
loss.backward()
opt.step()
losses.append(loss.item())
avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
print(f"Epoch {epoch}. Average loss: {avg_loss:.5f}")
3.6 测试模型
训练完成后,我们可以测试扩散模型的强大能力。
n_steps = 40
x = torch.rand(8, 1, 28, 28).to(device) # Start from random noise
step_history = [x.detach().cpu()]
pred_output_history = []
for i in range(n_steps):
with torch.no_grad(): # No gradients during inference
pred = net(x,0).sample # Predict denoised image
pred_output_history.append(pred.detach().cpu()) # Store prediction
mix_factor = 1 / (n_steps - i) # Mix towards prediction
x = x * (1 - mix_factor) + pred * mix_factor # Move partway to the denoised image
step_history.append(x.detach().cpu()) # Store for plotting
fig, axs = plt.subplots(n_steps, 2, figsize=(9, 32), sharex=True)
axs[0, 0].set_title("Input (noisy)")
axs[0, 1].set_title("Model Prediction")
for i in range(n_steps):
axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap="Greys")
axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0, 1), cmap="Greys")
这展示了模型测试的最终结果,输出看起来相当令人印象深刻!
总结
至此,扩散模型的解释就结束了。你可以尝试修改UNet网络或调整参数,看看能否获得更出色的结果!
如果你觉得这篇博文有帮助,请考虑点赞 🤗。