cocogold:训练 Marigold 进行文本引导分割

社区文章 发布于 2025 年 7 月 8 日

Marigold 是一种基于扩散的深度估计方法(论文演示),后来扩展到其他任务,例如法线估计(论文演示)。我一直想知道类似的方法是否可以应用于分割,因为它与 Marigold 设计的任务有一些不同。我一直在把它作为一个断断续续的副项目来做,这是我的报告(TL;DR:我训练了一个概念验证模型,它奏效了😎)

我喜欢 Marigold,因为它展示了如何巧妙地利用现有开放模型(Stable Diffusion)并将其微调用于不同任务。由于 Stable Diffusion 在大量图像上进行了训练,它的图像理解能力非常出色。通过重用这种丰富的图像表示知识,Marigold 可以在短短几天内使用单个消费级 GPU 进行微调。Stable Diffusion 是一个很好的选择,因为大部分计算发生在潜在空间而不是直接处理图像像素。潜在空间是一种花哨的说法,意味着我们能够将输入图像大幅压缩(在这种情况下是48倍),因此计算速度更快,所需内存也少得多。

然而,Marigold 仅仅将 Stable Diffusion(SD)用作视觉主干,完全忽略了 SD 同样能够理解描述视觉内容的文本。这是因为深度或法线计算是纯计算机视觉任务。与 Stable Diffusion 旨在根据文本描述生成图像不同,Marigold 使用原始图像作为模型的输入来生成图像(深度图)。这类似于图像到图像的生成任务,但我们只需要输入图像——不需要文本描述。

我想解决的问题是:我们是否可以使用类似的方法来估计分割掩码,并使用文本来描述我们想要查找的对象?

cocogold task

结果是可以!我做了一个概念验证作为副项目。为了简单起见,我使用了 COCO 数据集,并对 Marigold 方法进行了一些调整。我的副项目由此诞生,名为 cocogold

cocogold 有什么用?

如上面大象图片所示,cocogold 能够根据图片中任意对象的文本描述来估计其分割掩码。然而,并非所有图片都像这张一样容易。有时我们想要提取的对象并非占据照片大部分的大象。请看以下来自我们训练过程中使用的验证集的例子(我们将掩码显示为图片上叠加的白色区域,原因稍后解释)

cocogold examples

正如你所看到的,它适用于照片中不显眼的对象,包括小物体和部分被遮挡的物体。令人惊讶的是,它还能泛化到未见过的类别——该模型从未被训练识别大象,但正如你在这篇文章的第一个例子中看到的,它能够做到!

这只是一个实验性的概念验证,并非最先进的分割方法。但它是利用图像生成模型解决文本引导计算机视觉任务的一种非常有效且有趣的方式!

它是如何工作的?

为了了解 cocogold 的工作原理,让我们看看它与 Marigold 的不同之处。

Marigold 方法

Marigold 基于 Stable Diffusion UNet,它是 Stable Diffusion 管道中的一个模型。在 Stable Diffusion 中,UNet 接受两个输入:一个带有噪声的输入图像和一个文本提示。UNet 的工作,从概念上讲,是预测一个噪声较少的输入图像。它被设计成一个循环:从纯高斯噪声开始,它逐步从噪声中雕刻出图像,使用文本描述作为指导,以达到我们想要的结果。

Stable Diffusion Denoising
针对提示词“Pedro Cuenca 的快乐肖像画”的去噪过程。免责声明:我不是那个人。

Marigold 使用了修改后的 UNet

  • 文本输入被忽略。
  • 它接受两张图像作为输入,而不是文本+图像。第一张图像是原始图像,我们希望从中估计深度图。它不是文本提示,而是用于调节和引导生成过程的。第二张图像是带有噪声的:我们从纯高斯噪声开始,模型逐渐预测深度图,使用原始图像作为参考。

cocogold

cocogold 的管道在概念上非常简单——但通常情况下,魔鬼藏在细节中。对于训练,我们将使用三个输入

  • 指定要分割对象的文本输入。原始的 Stable Diffusion UNet 已准备好与文本配合使用,因此我们只需传递 Marigold 忽略的文本输入即可。
  • 原始图像。
  • COCO 数据集中的分割掩码,我们将用作真实值。

数据集准备

为了准备训练数据集,我下载了原始 COCO 数据集,并将全景分割掩码(图像中的每个像素都分配给一个支持的类别)转换为对象分割(我们只对与给定类别匹配的像素感兴趣)。然后我写了一个简单的库,它执行以下操作

  • 检索图像和类别。
  • 获取与该类别对应的分割掩码。
  • 随机裁剪一个正方形以方便训练。
  • 应用一些简单的启发式方法,以避免在应用随机裁剪后被严重裁剪的掩码,并为更大的掩码分配更高的优先级。

我将此封装在 PyTorch 的 Dataset 类中。每次迭代数据集时,都会获得不同的裁剪和掩码,因此我在训练期间没有费心进行额外的增强。

我使用 nbdev 编写了这个库,因此你可以在探索笔记本中详细了解该过程。

训练

训练是使用 Marigold 的这个分支完成的,我在此分支中应用了上述修改并创建了一个特定的训练脚本。代码很丑陋,因为我硬编码了大部分所需内容,而没有担心与 Marigold 保持兼容性。此外,我在几个月前分叉了它,此后再也没有与他们的代码库同步。我总共花了大约 5 或 6 天的时间,但整个过程持续了数周,因为我只是偶尔回来处理一下。

比我的糟糕工程实践更有趣的是,我的第一次训练运行失败了。

cocogold failed run

训练运行看起来很有希望:仅在几步之后,我就得到了合理的掩码。然而,更长时间的训练并没有改善结果,我看到了许多退化的情况和在整个过程中未能识别的对象。我对这种情况发生的原因有一些猜测

  • 模型被训练来预测两个值:背景(显示为黑色)和掩码。在大多数训练案例中,一个区域比另一个区域占据更大的面积——通常背景比要分割的主体更大。由于我使用 MSE 作为损失函数,我认为模型可能正在学习预测最频繁的值以最小化损失。

这只在概念层面是正确的,因为 MSE 损失是在潜在空间中计算的,而不是比较解码后的像素。但即便如此,数据集样本仍然不平衡,即使在潜在空间中也是如此。

  • 我怀疑有些物体通常很小,模型可能会了解到,例如,“香蕉”总是指小物体。我没有完全验证这是否一致。
  • 数据集也存在严重的类别不平衡,其中 person 类别过多。这是对整个训练数据集的一次完整迭代的分析(请记住样本是随机的,因此数值不一定精确)

cocogold dataset class distribution

所以长时间训练可能会导致对 person 主题的过拟合。我应该早点进行这项分析:(

为了解决第一个问题,我尝试调整我的损失函数,使其模型更难选择多数类别。事实证明,这方面已经有研究(当然),人们为此目的使用类别加权焦点损失。但这些方法在像素空间中有意义,我不想用像素计算损失——要做到这一点,我必须解码潜在变量以将其转换为图像,这既昂贵又占用大量内存。我为潜在空间编写了一个非常朴素的焦点损失函数,但它也不起作用。

修复

在与损失函数搏斗了一段时间后,我突然意识到解决方案可能简单得多!与其强制模型学习在两个值之间进行选择,我意识到我可以控制数据并以不同方式准备输入。我没有提供二值分割掩码作为真实值,而是在原始图像上绘制了掩码。这对于一个经过数百万步训练以生成由像素组成的图像的模型来说,更加自然,因为它以前很少遇到二值掩码!

此外,我还将训练类别限制为以下几类

valid_categories = [
    "car", "dining table", "chair", "train", "airplane",
    "giraffe", "clock", "toilet", "bed", "bird",
    "truck", "cat", "horse", "dog",
]

我排除了严重过多的 person 类别,并保留了前 14 个剩余类别。

通过这些修改,训练进展顺利,结果如开头所示。模型学会了预测几乎与输入图像完全相同的副本,只是它使用白色块来标记我们感兴趣的分割对象。

需要多长时间?

使用像 Marigold 这样专注的方法的优势之一是迭代速度快,因为我们利用了一个对图像表示了解很多且能快速学习新任务的模型。我在一个 A6000 Ada GPU 上训练了大约 40 小时(约 18,000 步)——这就像一个 4090,但有 48 GB 内存而不是 24 GB。在约 5,000 步时结果就相当不错了,但我继续训练更长时间并保存中间检查点,因为我想稍后在各个阶段进行测试。

我使用了 float32,没有 LoRA。训练速度可能会快很多,但我不想在最初的实验阶段引入不确定性。

推理后处理

我们还没有完成。我们成功创建了一个模型,它不会预测分割掩码,而是在图像顶部预测分割掩码(白色)。我们实际上需要从预测图像创建二值分割掩码,因为这是我们下游任务所需的。

为了创建分割掩码,我的方法是过滤输出,保留偏白像素,丢弃其余部分。我们必须以一定的容差进行过滤,因为模型不会为掩码中的每个像素预测完美的 1.0 像素,而是像 0.9970.935 这样的值。然而,这有一个问题:真实图像也包含白色像素。它们很少是纯白色(所有通道都接近 1.0),但如果你曾见过在阳光下拍摄的、天空过曝泛白的图片,找到白色像素的可能性并不低。为了说明这个问题,这是选择偏白像素时的典型结果

Computing mask from prediction

如您所见,在不相关的地方有几个白色像素异常值。

为了解决这个问题,我求助于仍然非常有用的老式图像处理算法。对于小块离群斑点的情况,我们可以使用腐蚀操作(像素被附近的多数类别像素替换,去除噪声),然后是膨胀操作(我们通过在边界添加像素来扩展形状)。这会产生或多或少相同的形状,同时去除了离群值。我使用了一个小型的 3×3 卷积核来完成此操作,但您也可以通过池化来实现相同效果。这是经过后处理的、去除了离群值的掩码

Morphological opening

腐蚀和膨胀是形态学操作家族的一部分。腐蚀后接膨胀的组合也称为形态学开运算。您不需要知道这些名称,除非您在论文或其他地方遇到它们。

然而,这种算法对于实际包含接近白色像素的图像(例如那些天空过曝的图像)会失败。以下时钟图像示例说明了一个失败案例

Clock failing example

为了解决这个问题,我通过预处理输入来“作弊”。在运行推理之前,我将偏白像素去饱和,以确保输入中不包含白色像素。因为模型经过训练能够准确复制输入,所以输出中也不会有白色像素——除了掩码。这是去饱和图像和推断掩码的样子

Clock: desaturating before prediction

回想起来,我应该为掩码选择不同的颜色。纯绿色,就像视频场景中用于色度键的颜色,也许可以——如果它足以用于电视,那么用于训练也应该没问题,对吗?也许我可以快速微调模型来替换颜色,看看它是否可以在不后处理掩码或预处理输入的情况下工作。

泛化

如前所述,该模型从未训练过大象等类别,但它却能奏效。这证明了在 Stable Diffusion 原始训练过程中实现的图像和文本的强大表示;我们无需做任何事情,它就能直接工作。实际上,我发现有趣的是,早期的一些模型检查点,训练步骤很少,就能够预测出非常好的掩码,但在泛化方面表现不佳。经过更长时间的训练,模型在泛化方面才得以改进。我目前还不清楚为什么会是这种情况。

老兄,为什么不直接用 VLM?

VLM(视觉-语言模型)将“视觉理解”模型与 LLM 结合起来,因此它们可以根据我们在图像中看到的内容和提出的问题生成文本。它们正成为解决众多任务的通用解决方案,包括视觉问答、OCR、图像标注等等。

我对 VLM 的直觉是,它们肯定可以用于分割,就像 PaliGemma 所展示的那样。然而,PaliGemma 在预训练过程中使用了大量数据来学习几何感知任务。为通用 VLM 注入足够的知识以执行分割可能需要比我们在这个实验性训练运行中花费更多的数据和时间。此外,VLM 只能输出文本,因此在训练过程中需要考虑这一点:向模型添加新 token,正确初始化它们,创建一些东西(自定义 VAE)来编码/解码它们。

我认为扩散模型更容易,因为它们已经了解所有这些东西,所以训练快速而直接。但比较一下会很酷!

还有一种选择是使用常规分割模型,但挑战在于如何使其对文本做出适当响应。这当然是可能的,但您必须仔细设计您的文本编码、对齐、检测和分割流程。

资源汇总

接下来的计划

还有一些事情我还没做完,完成它们会很酷

  • 使用 IoU 等分割指标评估模型,看看它到底有多好。我只是想探索我们是否可以使用扩散训练一个文本引导的分割模型,但我不知道它的性能如何。
  • 衡量模型在我们训练中使用的类别和未使用的类别上的表现如何。
  • 训练更多类别。
  • 推理时的集成。由于扩散过程是随机的,每次运行推理时结果略有不同。原始的用于深度估计的 Marigold 使用几次预测的中位数(或平均值),但我尚未在此任务上测试。
  • 通过使用同义词或不同的短语来提高文本理解能力,而不仅仅是 COCO 类别名称。更好的方法是,我们可以使用 VLM 为所需对象创建一个简短的标题,并用它进行训练。在推理过程中,我们可以使用“请选择离门最近的女孩”之类的描述来给出指令。这应该会提高泛化能力,而且非常酷。

谢谢!❤️

特别感谢 Aritra 热情地投入到这个项目中并使其变得更好,期待在未来我们讨论的一些话题上继续合作!🫶

非常感谢我的 Hugging Frace 朋友和同事们 vbSergioSayakMerve 讨论这些想法,阅读这篇博文的草稿并提供宝贵建议。

这个项目如果没有 Marigold 是不可能的,感谢 Anton 和团队与社区分享,并给予灵感!

社区

真是个不错的项目。

这太棒了,谢谢分享!

注册登录 发表评论