使用 CLIPSeg 进行零样本图像分割

发布日期:2022 年 12 月 21 日
在 GitHub 上更新
Open In Colab

本指南展示了如何使用 CLIPSeg(一个零样本图像分割模型)和 🤗 transformers。CLIPSeg 创建的粗略分割掩码可用于机器人感知、图像修复和许多其他任务。如果您需要更精确的分割掩码,我们将展示如何在 Segments.ai 上优化 CLIPSeg 的结果。

图像分割是计算机视觉领域的一个众所周知的任务。它使计算机不仅能知道图像中有什么(分类),物体在图像中的位置(检测),还能知道这些物体的轮廓。了解物体的轮廓在机器人和自动驾驶等领域至关重要。例如,机器人必须知道物体的形状才能正确抓取它。分割还可以与 图像修复 结合,允许用户描述他们想要替换的图像部分。

大多数图像分割模型的一个限制是它们只适用于固定的类别列表。例如,您不能简单地使用在橙子上训练的分割模型来分割苹果。要教授分割模型一个额外的类别,您必须标记新类别的数据并训练一个新模型,这可能成本高昂且耗时。但是,如果有一个模型已经可以分割几乎任何种类的物体,而无需进一步训练,那会怎么样?这正是 CLIPSeg(一个零样本分割模型)所实现的。

目前,CLIPSeg 仍有其局限性。例如,该模型使用 352 x 352 像素的图像,因此输出分辨率相当低。这意味着当我们使用现代相机拍摄的图像时,我们无法期望获得像素级完美的结果。如果我们需要更精确的分割,我们可以对最先进的分割模型进行微调,如 我们之前的博客文章 所示。在这种情况下,我们仍然可以使用 CLIPSeg 生成一些粗略的标签,然后在使用 Segments.ai 等标注工具中对其进行优化。在我们描述如何做到这一点之前,我们首先看看 CLIPSeg 的工作原理。

CLIP:CLIPSeg 背后的神奇模型

CLIP,全称是 Contrastive Language–Image Pre-training(对比语言-图像预训练),是 OpenAI 于 2021 年开发的一个模型。您可以给 CLIP 输入一张图像或一段文本,CLIP 将输出您输入的抽象表示。这种抽象表示,也称为嵌入,实际上只是一个向量(一串数字)。您可以将此向量视为高维空间中的一个点。CLIP 经过训练,使得相似图像和文本的表示也相似。这意味着如果我们输入一张图像和一段与该图像匹配的文本描述,图像和文本的表示将相似(即,高维点将彼此靠近)。

起初,这可能看起来不是很有用,但它实际上非常强大。举个例子,让我们快速看看 CLIP 如何用于图像分类,而从未在该任务上进行过训练。为了对图像进行分类,我们将图像和我们想要选择的不同类别输入到 CLIP(例如,我们输入一张图像和“苹果”、“橙子”等词)。然后 CLIP 返回图像和每个类别的嵌入。现在,我们只需检查哪个类别嵌入与图像的嵌入最接近,瞧!感觉像魔法,不是吗?

使用 CLIP 进行图像分类的示例 (来源)。

更重要的是,CLIP 不仅对分类有用,它还可以用于图像搜索(您能看出这与分类有何相似之处吗?)、文本到图像模型DALL-E 2 由 CLIP 提供支持)、对象检测OWL-ViT),以及对我们来说最重要的是:图像分割。现在您明白为什么 CLIP 确实是机器学习领域的一个突破了。

CLIP 之所以效果如此好,是因为该模型在包含文本标题的大型图像数据集上进行了训练。该数据集包含多达 4 亿张从互联网上获取的图像-文本对。这些图像包含各种各样的对象和概念,CLIP 在为它们中的每一个创建表示方面表现出色。

CLIPSeg:使用 CLIP 进行图像分割

CLIPSeg 是一个使用 CLIP 表示来创建图像分割掩码的模型。它由 Timo Lüddecke 和 Alexander Ecker 发布。他们通过在 CLIP 模型之上训练一个基于 Transformer 的解码器来实现零样本图像分割,该解码器保持冻结状态。解码器接收图像的 CLIP 表示以及您要分割的事物的 CLIP 表示。使用这两个输入,CLIPSeg 解码器创建一个二进制分割掩码。更精确地说,解码器不仅使用我们要分割的图像的最终 CLIP 表示,它还使用 CLIP 的某些层的输出。

源码

该解码器在 PhraseCut 数据集上进行训练,该数据集包含超过 340,000 个短语及其对应的图像分割掩码。作者还尝试了各种增强方法来扩大数据集的大小。这里的目标不仅是能够分割数据集中存在的类别,还要能够分割未见过的类别。实验确实表明解码器可以泛化到未见过的类别。

CLIPSeg 的一个有趣特性是,查询(我们要分割的图像)和提示(我们要分割的图像中的事物)都作为 CLIP 嵌入输入。提示的 CLIP 嵌入可以来自一段文本(类别名称),**或者来自另一张图像**。这意味着您可以通过向 CLIPSeg 提供一个橙子的示例图像来分割照片中的橙子。

这种技术被称为“视觉提示”,当您想要分割的事物难以描述时,它非常有用。例如,如果您想分割 T 恤图片中的徽标,描述徽标的形状并不容易,但 CLIPSeg 允许您只需使用徽标的图像作为提示。

CLIPSeg 论文包含了一些关于提高视觉提示有效性的技巧。他们发现裁剪查询图像(使其仅包含您要分割的对象)非常有帮助。模糊和调暗查询图像的背景也有一些帮助。在下一节中,我们将展示如何使用 🤗 transformers 亲自尝试视觉提示。

将 CLIPSeg 与 Hugging Face Transformers 结合使用

使用 Hugging Face Transformers,您可以轻松下载并在图像上运行预训练的 CLIPSeg 模型。让我们从安装 transformers 开始。

!pip install -q transformers

要下载模型,只需实例化它即可。

from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

现在我们可以加载一张图像来尝试分割。我们将选择一张由 Calum Lewis 拍摄的美味早餐图片。

from PIL import Image
import requests

url = "https://unsplash.com/photos/8Nc_oQsc2qQ/download?ixid=MnwxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjcxMjAwNzI0&force=true&w=640"
image = Image.open(requests.get(url, stream=True).raw)
image

文本提示

让我们从定义一些我们想要分割的文本类别开始。

prompts = ["cutlery", "pancakes", "blueberries", "orange juice"]

现在我们有了输入,我们可以处理它们并将它们输入到模型中。

import torch

inputs = processor(text=prompts, images=[image] * len(prompts), padding="max_length", return_tensors="pt")
# predict
with torch.no_grad():
  outputs = model(**inputs)
preds = outputs.logits.unsqueeze(1)

最后,让我们可视化输出。

import matplotlib.pyplot as plt

_, ax = plt.subplots(1, len(prompts) + 1, figsize=(3*(len(prompts) + 1), 4))
[a.axis('off') for a in ax.flatten()]
ax[0].imshow(image)
[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(len(prompts))];
[ax[i+1].text(0, -15, prompt) for i, prompt in enumerate(prompts)];

视觉提示

如前所述,我们也可以使用图像作为输入提示(即代替类别名称)。如果难以描述您要分割的事物,这会特别有用。在此示例中,我们将使用由 Daniel Hooper 拍摄的咖啡杯图片。

url = "https://unsplash.com/photos/Ki7sAc8gOGE/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MTJ8fGNvZmZlJTIwdG8lMjBnb3xlbnwwfHx8fDE2NzExOTgzNDQ&force=true&w=640"
prompt = Image.open(requests.get(url, stream=True).raw)
prompt

现在我们可以处理输入图像和提示图像,并将它们输入到模型中。

encoded_image = processor(images=[image], return_tensors="pt")
encoded_prompt = processor(images=[prompt], return_tensors="pt")
# predict
with torch.no_grad():
  outputs = model(**encoded_image, conditional_pixel_values=encoded_prompt.pixel_values)
preds = outputs.logits.unsqueeze(1)
preds = torch.transpose(preds, 0, 1)

然后,我们可以像以前一样可视化结果。

_, ax = plt.subplots(1, 2, figsize=(6, 4))
[a.axis('off') for a in ax.flatten()]
ax[0].imshow(image)
ax[1].imshow(torch.sigmoid(preds[0]))

让我们最后一次尝试使用论文中描述的视觉提示技巧,即裁剪图像和调暗背景。

url = "https://i.imgur.com/mRSORqz.jpg"
alternative_prompt = Image.open(requests.get(url, stream=True).raw)
alternative_prompt
encoded_alternative_prompt = processor(images=[alternative_prompt], return_tensors="pt")
# predict
with torch.no_grad():
  outputs = model(**encoded_image, conditional_pixel_values=encoded_alternative_prompt.pixel_values)
preds = outputs.logits.unsqueeze(1)
preds = torch.transpose(preds, 0, 1)
_, ax = plt.subplots(1, 2, figsize=(6, 4))
[a.axis('off') for a in ax.flatten()]
ax[0].imshow(image)
ax[1].imshow(torch.sigmoid(preds[0]))

在这种情况下,结果几乎相同。这可能是因为咖啡杯在原始图像中已经与背景很好地分离了。

使用 CLIPSeg 在 Segments.ai 上预标注图像

如您所见,CLIPSeg 的结果有点模糊且分辨率很低。如果我们需要获得更好的结果,您可以微调最先进的分割模型,如 我们之前的博客文章 中所述。为了微调模型,我们需要标注数据。在本节中,我们将向您展示如何使用 CLIPSeg 创建一些粗略的分割掩码,然后在 Segments.ai(一个带有智能图像分割标注工具的标注平台)上对其进行优化。

首先,在 https://segments.ai/join 创建一个帐户并安装 Segments Python SDK。然后,您可以使用 API 密钥初始化 Segments.ai Python 客户端。此密钥可以在 帐户页面 上找到。

!pip install -q segments-ai
from segments import SegmentsClient
from getpass import getpass

api_key = getpass('Enter your API key: ')
segments_client = SegmentsClient(api_key)

接下来,让我们使用 Segments 客户端从数据集中加载一张图像。我们将使用 a2d2 自动驾驶数据集。您也可以按照 这些说明 创建自己的数据集。

samples = segments_client.get_samples("admin-tobias/clipseg")

# Use the last image as an example
sample = samples[1]
image = Image.open(requests.get(sample.attributes.image.url, stream=True).raw)
image

我们还需要从数据集属性中获取类别名称。

dataset = segments_client.get_dataset("admin-tobias/clipseg")
category_names = [category.name for category in dataset.task_attributes.categories]

现在我们可以像以前一样在图像上使用 CLIPSeg。这次,我们还会将输出放大,使其与输入图像的大小匹配。

from torch import nn

inputs = processor(text=category_names, images=[image] * len(category_names), padding="max_length", return_tensors="pt")

# predict
with torch.no_grad():
  outputs = model(**inputs)

# resize the outputs
preds = nn.functional.interpolate(
    outputs.logits.unsqueeze(1),
    size=(image.size[1], image.size[0]),
    mode="bilinear"
)

然后我们可以再次可视化结果。

len_cats = len(category_names)
_, ax = plt.subplots(1, len_cats + 1, figsize=(3*(len_cats + 1), 4))
[a.axis('off') for a in ax.flatten()]
ax[0].imshow(image)
[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(len_cats)];
[ax[i+1].text(0, -15, category_name) for i, category_name in enumerate(category_names)];

现在我们必须将预测结果组合成一个单一的分割图像。我们只需对每个补丁选择 Sigmoid 值最大的类别即可。我们还会确保低于某个阈值的所有值都不计入在内。

threshold = 0.1

flat_preds = torch.sigmoid(preds.squeeze()).reshape((preds.shape[0], -1))

# Initialize a dummy "unlabeled" mask with the threshold
flat_preds_with_treshold = torch.full((preds.shape[0] + 1, flat_preds.shape[-1]), threshold)
flat_preds_with_treshold[1:preds.shape[0]+1,:] = flat_preds

# Get the top mask index for each pixel
inds = torch.topk(flat_preds_with_treshold, 1, dim=0).indices.reshape((preds.shape[-2], preds.shape[-1]))

让我们快速可视化结果。

plt.imshow(inds)

最后,我们可以将预测上传到 Segments.ai。为此,我们首先将位图转换为 png 文件,然后将此文件上传到 Segments,最后将标签添加到样本中。

from segments.utils import bitmap2file
import numpy as np

inds_np = inds.numpy().astype(np.uint32)
unique_inds = np.unique(inds_np).tolist()
f = bitmap2file(inds_np, is_segmentation_bitmap=True)

asset = segments_client.upload_asset(f, "clipseg_prediction.png")

attributes = {
      'format_version': '0.1',
      'annotations': [{"id": i, "category_id": i} for i in unique_inds if i != 0],
      'segmentation_bitmap': { 'url': asset.url },
  }

segments_client.add_label(sample.uuid, 'ground-truth', attributes)

如果你查看 Segments.ai 上上传的预测,你会发现它并不完美。但是,你可以手动纠正最大的错误,然后你可以使用纠正后的数据集来训练一个比 CLIPSeg 更好的模型。

总结

CLIPSeg 是一种零样本分割模型,可以使用文本和图像提示。该模型在 CLIP 的基础上添加了一个解码器,几乎可以分割任何东西。然而,目前输出的分割掩码分辨率仍然很低,因此如果精度很重要,您可能仍然需要微调不同的分割模型。

请注意,目前正在进行更多关于零样本分割的研究,因此预计在不久的将来会添加更多模型。一个例子是 GroupViT,它已在 🤗 Transformers 中可用。要了解分割研究的最新新闻,您可以关注我们的 Twitter:@TobiasCornille@NielsRogge@huggingface

如果您有兴趣了解如何微调最先进的分割模型,请查看我们之前的博客文章:https://huggingface.co/blog/fine-tune-segformer

社区

注册登录以评论