社区计算机视觉课程文档

OneFormer:一个 Transformer 统治通用图像分割

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

OneFormer:一个 Transformer 统治通用图像分割

介绍

OneFormer 是图像分割领域的一项突破性方法。图像分割是一项计算机视觉任务,涉及将图像划分为有意义的片段。传统方法针对不同的分割任务使用单独的模型和架构,例如识别对象(实例分割)或标记区域(语义分割)。最近的一些尝试旨在通过共享架构来统一这些任务,但仍然需要针对每个任务进行单独的训练。

OneFormer 横空出世,这是一个旨在克服这些挑战的通用图像分割框架。它引入了一种独特的多任务方法,允许单个模型处理语义分割、实例分割和全景分割任务,而无需对每个任务进行单独的训练。其关键创新在于任务条件联合训练策略,其中模型由任务输入引导,使其在训练和推理过程中都能够动态地适应不同的任务。

这项突破不仅简化了训练过程,而且在各种数据集上的性能都优于现有模型。OneFormer 通过使用全景注释来实现这一点,统一了所有任务所需的真值信息。此外,该框架还引入了查询-文本对比学习,以更好地区分任务并提高整体性能。

OneFormer 的背景

Oneformer 方法 图片来自 OneFormer 论文

为了理解 OneFormer 的重要性,让我们回顾一下图像分割的背景。在图像处理中,分割涉及将图像划分为不同的部分,这对于识别对象和理解场景内容等任务至关重要。传统上,有两种主要的分割任务:语义分割,其中像素被标记为“道路”或“天空”等类别;以及实例分割,它识别具有明确边界的对象。

随着时间的推移,研究人员提出了全景分割作为统一语义分割和实例分割任务的一种方法。然而,即使有了这些进步,仍然存在挑战。为全景分割设计的现有模型仍然需要针对每个任务进行单独的训练,这使得它们充其量只能算是半通用。

这就是 OneFormer 作为游戏规则改变者出现的地方。它引入了一种新颖的方法——多任务通用架构。其理念是仅使用单个通用架构、单个模型和一个数据集对该框架进行一次训练。神奇之处在于 OneFormer 在语义分割、实例分割和全景分割任务中都优于专门的框架。这项突破不仅仅是为了提高准确性,而是为了使图像分割更加通用和高效。有了 OneFormer,对不同任务进行大量资源和单独训练的需求将成为过去。

OneFormer 的核心概念

Task Conditioned Joint Training

现在,让我们分解一下 OneFormer 的关键特性,使其脱颖而出

任务动态掩码

OneFormer 使用了一个名为“任务动态掩码”的巧妙技巧,以更好地理解和处理不同类型的图像分割任务。因此,当模型遇到图像时,它会使用这个“任务动态掩码”来决定是关注整体场景、识别具有清晰边界的特定对象,还是两者兼顾。

任务条件联合训练

OneFormer 的突破性特征之一是其任务条件联合训练策略。OneFormer 不是分别针对语义分割、实例分割和全景分割进行训练,而是在训练期间均匀地采样任务。这种策略使模型能够同时学习和泛化到不同的分割任务。通过任务 token 将架构置于特定任务的条件下,OneFormer 统一了训练过程,减少了对特定于任务的架构、模型和数据集的需求。这种创新方法显着简化了训练流程和资源需求。

查询-文本对比损失

最后,让我们谈谈“查询-文本对比损失”。可以将其视为 OneFormer 自学任务和类别之间差异的一种方式。在训练过程中,模型将从图像中提取的特征(查询)与相应的文本描述(如“一张有汽车的照片”)进行比较。这有助于模型理解每个任务的独特特征,并减少不同类别之间的混淆。OneFormer 的“任务动态掩码”使其像多任务助手一样通用,“查询-文本对比损失”通过将视觉特征与文本描述进行比较,帮助其学习每个任务的细节。

通过结合这些核心概念,OneFormer 成为一个智能高效的图像分割工具,使该过程更加通用和易于使用。

结论

结果比较 图片来自 OneFormer 论文

总之,OneFormer 框架代表了图像分割领域的一项突破性方法,旨在简化和统一跨各种领域的任务。与依赖于每个分割任务的专用架构的传统方法不同,OneFormer 引入了一种新颖的多任务通用架构,该架构仅需要单个模型,在通用数据集上训练一次即可超越现有框架。此外,在训练过程中结合查询-文本对比损失增强了模型学习任务间和类间差异的能力。OneFormer 利用基于 Transformer 的架构,灵感来自计算机视觉领域的最新成功,并引入任务引导的查询来提高任务敏感性。结果令人印象深刻,因为 OneFormer 在 ADE20k、Cityscapes 和 COCO 等基准数据集上的语义分割、实例分割和全景分割任务中超越了最先进的模型。使用新的 ConvNeXt 和 DiNAT 主干网络进一步增强了该框架的性能。

总而言之,OneFormer 代表了朝着通用且易于使用的图像分割迈出的重要一步。通过引入能够处理各种分割任务的单个模型,该框架简化了分割过程并减少了资源需求。

模型的使用示例

让我们来看一个模型的使用示例。Dinat 主干网络需要 Natten 库,安装可能需要一段时间。

!pip install -q natten 

我们可以在下面看到一个推理代码,具体取决于不同的分割类型。

from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
from PIL import Image
import requests
import matplotlib.pyplot as plt


def run_segmentation(image, task_type):
    """Performs image segmentation based on the given task type.

    Args:
        image (PIL.Image): The input image.
        task_type (str): The type of segmentation to perform ('semantic', 'instance', or 'panoptic').

    Returns:
        PIL.Image: The segmented image.

    Raises:
        ValueError: If the task type is invalid.
    """

    processor = OneFormerProcessor.from_pretrained(
        "shi-labs/oneformer_ade20k_dinat_large"
    )  # Load once here
    model = OneFormerForUniversalSegmentation.from_pretrained(
        "shi-labs/oneformer_ade20k_dinat_large"
    )

    if task_type == "semantic":
        inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt")
        outputs = model(**inputs)
        predicted_map = processor.post_process_semantic_segmentation(
            outputs, target_sizes=[image.size[::-1]]
        )[0]

    elif task_type == "instance":
        inputs = processor(images=image, task_inputs=["instance"], return_tensors="pt")
        outputs = model(**inputs)
        predicted_map = processor.post_process_instance_segmentation(
            outputs, target_sizes=[image.size[::-1]]
        )[0]["segmentation"]

    elif task_type == "panoptic":
        inputs = processor(images=image, task_inputs=["panoptic"], return_tensors="pt")
        outputs = model(**inputs)
        predicted_map = processor.post_process_panoptic_segmentation(
            outputs, target_sizes=[image.size[::-1]]
        )[0]["segmentation"]

    else:
        raise ValueError(
            "Invalid task type. Choose from 'semantic', 'instance', or 'panoptic'"
        )

    return predicted_map


def show_image_comparison(image, predicted_map, segmentation_title):
    """Displays the original image and the segmented image side-by-side.

    Args:
        image (PIL.Image): The original image.
        predicted_map (PIL.Image): The segmented image.
        segmentation_title (str): The title for the segmented image.
    """

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title("Original Image")
    plt.axis("off")
    plt.subplot(1, 2, 2)
    plt.imshow(predicted_map)
    plt.title(segmentation_title + " Segmentation")
    plt.axis("off")
    plt.show()


url = "https://huggingface.co/datasets/shi-labs/oneformer_demo/resolve/main/ade20k.jpeg"
response = requests.get(url, stream=True)
response.raise_for_status()  # Check for HTTP errors
image = Image.open(response.raw)

task_to_run = "semantic"
predicted_map = run_segmentation(image, task_to_run)
show_image_comparison(image, predicted_map, task_to_run)

semantic segmentation

参考文献

< > 在 GitHub 上更新