理解 VQ-VAE 中的向量量化
向量量化变分自编码器 (VQ-VAE) 利用一种独特的机制,即向量量化,将连续的潜在表示映射为离散嵌入。在本文中,我将尝试以更实际的方式解释该机制。
初始化层
class VQEmbedding(nn.Module):
def __init__(self, num_embeddings, embedding_dim):
super().__init__()
self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
self.embedding.weight.data.uniform_(-1/self.num_embeddings, 1/self.num_embeddings)
VQEmbedding
类旨在创建和管理嵌入矩阵(码本嵌入),其中每行代表模型可以选择的可能离散嵌入。此矩阵的形状由 num_embeddings
(嵌入数量)和 embedding_dim
(每个嵌入向量的大小)定义。
初始化过程的关键部分是使用均匀分布设置嵌入权重。具体来说,每个权重被分配一个介于 -1/self.num_embeddings
和 1/self.num_embeddings
之间的值,确保初始值在此范围内均匀分布。这种均匀初始化很重要,因为它能防止训练开始时出现任何偏差。通过避免过大或过小的初始值,模型以中性状态开始,这有助于促进平衡学习。
扁平化以获得灵活性
def forward(self, z):
b, c, h, w = z.shape
z_channel_last = z.permute(0, 2, 3, 1)
z_flattened = z_channel_last.reshape(b*h*w, self.embedding_dim)
向量量化的第一步是扁平化编码输入。通常,来自图像的编码输入具有 [Batch, embedding_dim, h, w]
的形状。通过扁平化此张量,我们将其转换为 [Batch * h * h, embedding_dim]
。此转换不仅简化了后续操作,还使模块具有通用性,可兼容各种输入形状。
距离计算
向量量化的核心在于编码向量和码本嵌入之间的距离计算。为了计算距离,我们使用均方误差(MSE)损失。两个向量 (原始向量)和 (量化向量)之间的 MSE 可以表示为
其中:
- 是向量中的元素数量。
- 和 是向量 和 的对应元素。
此 MSE 损失可以使用差的平方公式重写
将此代入 MSE 公式,我们得到
# Calculate distances between z and the codebook embeddings |a-b|²
distances = (
torch.sum(z_flattened ** 2, dim=-1, keepdim=True) # a²
+ torch.sum(self.embedding.weight.t() ** 2, dim=0, keepdim=True) # b²
- 2 * torch.matmul(z_flattened, self.embedding.weight.t()) # -2ab
)
这里,理解矩阵的形状至关重要
- 扁平化的编码输入形状为
[b*h*w, embedding_dim]
。 - 嵌入矩阵(码本)的权重形状为
[num_embeddings, embedding_dim]
。
通过仔细转置,我们确保操作正确对齐,从而得到形状为 [b*h*w, num_embeddings]
的距离矩阵。此矩阵包含每个编码输入向量与所有码本嵌入之间的距离。
选择最近的码本嵌入
一旦我们有了距离矩阵,下一步就是识别每个向量的最小距离索引。这种选择过程,虽然让人联想到注意力机制(主要区别在于注意力机制侧重于最大值),但它允许我们将每个输入向量映射到其最近的码本条目。
# Get the index with the smallest distance
encoding_indices = torch.argmin(distances, dim=-1)
量化和重塑
有了最近码本嵌入的索引,我们使用 PyTorch 的 nn.Embedding
模块检索量化向量。这些向量的形状现在为 [b*h*w, embedding_dim]
,它们被重塑回原始空间维度并传递给解码器。
# Get the quantized vector
z_q = self.embedding(encoding_indices)
z_q = z_q.reshape(b, h, w, self.embedding_dim)
z_q = z_q.permute(0, 3, 1, 2)
损失和梯度流
在阅读 VQ-VAE 时,我发现码本思想并不是最突出的,而是作者如何设法传播梯度以使模型实现端到端训练。
在 VQ-VAE 中,承诺损失(commitment loss)在确保编码器网络提交到准确表示输入的特定码本条目方面起着关键作用。如果没有这种承诺,编码器可能会产生与可用码本条目不完全对齐的输出,从而导致重建质量不佳。承诺损失通常是连续编码向量与其对应量化版本之间的均方误差(MSE)。其思想是,当编码器的输出偏离所选码本条目太远时,惩罚编码器,鼓励编码器生成更接近码本中离散嵌入的表示。此损失项有助于稳定训练,并确保编码器和码本协同工作,从而提高学习表示的整体质量。
# Calculate the commitment loss
loss = F.mse_loss(z_q, z.detach()) + commitment_cost * F.mse_loss(z_q.detach(), z)
# Straight-through estimator trick for gradient backpropagation
z_q = z + (z_q - z).detach()
return z_q, loss, encoding_indices
直通估计器是一种巧妙的技术。挑战在于,将连续向量映射到离散码本条目的量化过程是不可微的。这种不可微性阻碍了梯度在网络中的反向传播,使得使用标准反向传播训练模型变得困难。直通估计器通过允许梯度绕过不可微的量化步骤来解决此问题。具体而言,它在反向传播过程中将离散量化输出视为连续输出,从而有效地将梯度从量化向量复制到原始连续向量。这种技巧使得模型能够端到端训练,尽管存在离散变量,但仍保持了基于梯度的优化的优势。
通过将直通估计器与承诺损失相结合,VQ-VAE 成功地平衡了离散表示的需求和基于梯度的优化的优势,使模型能够学习丰富、量化的嵌入,这些嵌入既适用于下游任务,又易于在训练期间优化。
整合
class VQEmbedding(nn.Module):
def __init__(self, num_embeddings, embedding_dim):
super().__init__()
self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
self.embedding.weight.data.uniform_(-1/self.num_embeddings, 1/self.num_embeddings)
def forward(self, z):
b, c, h, w = z.shape
z_channel_last = z.permute(0, 2, 3, 1)
z_flattened = z_channel_last.reshape(b*h*w, self.embedding_dim)
# Calculate distances between z and the codebook embeddings |a-b|²
distances = (
torch.sum(z_flattened ** 2, dim=-1, keepdim=True) # a²
+ torch.sum(self.embedding.weight.t() ** 2, dim=0, keepdim=True) # b²
- 2 * torch.matmul(z_flattened, self.embedding.weight.t()) # -2ab
)
# Get the index with the smallest distance
encoding_indices = torch.argmin(distances, dim=-1)
# Get the quantized vector
z_q = self.embedding(encoding_indices)
z_q = z_q.reshape(b, h, w, self.embedding_dim)
z_q = z_q.permute(0, 3, 1, 2)
# Calculate the commitment loss
loss = F.mse_loss(z_q, z.detach()) + commitment_cost * F.mse_loss(z_q.detach(), z)
# Straight-through estimator trick for gradient backpropagation
z_q = z + (z_q - z).detach()
return z_q, loss, encoding_indices
也可以访问此仓库,查看在 CIFAR10 数据集上训练 VQ-VAE 的情况。