OneFormer:一个Transformer 模型统治所有图像分割任务
简介
OneFormer 是一种突破性的图像分割方法,图像分割是计算机视觉任务中一项涉及将图像划分为有意义的片段的任务。传统方法使用不同的模型和架构来处理不同的分割任务,例如识别物体(实例分割)或标记区域(语义分割)。最近的尝试旨在使用共享架构统一这些任务,但仍然需要针对每个任务进行单独训练。
OneFormer 是一种通用的图像分割框架,旨在克服这些挑战。它引入了独特的多任务方法,允许单个模型处理语义、实例和全景分割任务,而无需针对每个任务进行单独训练。关键创新在于任务条件联合训练策略,其中模型由任务输入引导,使其在训练和推理过程中都能动态适应不同的任务。
这种突破不仅简化了训练过程,而且在各种数据集上超越了现有模型。OneFormer 通过使用全景注释来实现这一目标,将所有任务所需的真实信息统一起来。此外,该框架引入了查询文本对比学习,以更好地区分任务并提高整体性能。
OneFormer 的背景
图片来自 OneFormer 论文
为了理解 OneFormer 的意义,让我们深入了解图像分割的背景。在图像处理中,分割涉及将图像划分为不同的部分,这对于识别物体和理解场景内容等任务至关重要。传统上,语义分割和实例分割是图像分割的两大主要任务。语义分割将像素标记为“道路”或“天空”等类别,而实例分割则识别具有明确边界的目标。
随着时间的推移,研究人员提出了全景分割,它试图统一语义分割和实例分割任务。然而,即使有了这些进步,也面临着挑战。为全景分割设计的现有模型仍然需要针对每个任务进行单独训练,这使得它们充其量只能算是半通用的。
这就是 OneFormer 作为游戏规则改变者的出现。它引入了新颖的方法——多任务通用架构。其理念是只训练一次该框架,使用单个通用架构,一个模型和一个数据集。OneFormer 的神奇之处在于它在语义、实例和全景分割任务中超越了专门的框架。这种突破不仅仅是关于提高准确性;它还关于使图像分割更加通用和高效。使用 OneFormer,针对不同任务进行大量资源和单独训练的需求将成为过去。
OneFormer 的核心概念
现在,让我们分解 OneFormer 的关键特性,使其脱颖而出。
任务动态掩码
OneFormer 使用了一个巧妙的技巧,称为“任务动态掩码”,以更好地理解和处理不同类型的图像分割任务。因此,当模型遇到图像时,它会使用此“任务动态掩码”来决定是关注整个场景、识别具有清晰边界的特定物体,还是同时做这两件事。
任务条件联合训练
OneFormer 的一项突破性功能是其**任务条件联合训练策略**。与分别训练语义分割、实例分割和全景分割不同,OneFormer 在训练期间统一采样任务。这种策略使模型能够同时学习和泛化不同的分割任务。通过在任务令牌中对特定任务进行架构条件,OneFormer 统一了训练过程,减少了对特定任务架构、模型和数据集的需求。这种创新方法极大地简化了训练流程,降低了资源需求。
查询-文本对比损失
最后,我们来谈谈“**查询-文本对比损失**”。可以把它看作是 OneFormer 自学任务和类别之间差异的一种方式。在训练过程中,模型会比较从图像中提取的特征(查询)与相应的文本描述(例如“一张有汽车的照片”)。这有助于模型理解每个任务的独特特征,并减少不同类别之间的混淆。OneFormer 的“**任务动态掩码**”使其像多任务助手一样灵活多变,而“**查询-文本对比损失**”则通过将视觉特征与文本描述进行比较,帮助其学习每个任务的细节。
通过结合这些核心概念,OneFormer 成为一种用于图像分割的智能且高效的工具,使分割过程更加通用和易于访问。
结论
取自 OneFormer 论文的图片
总之,OneFormer 框架代表了图像分割领域的一种突破性方法,旨在简化和统一跨不同领域的分割任务。与传统的依赖于每个分割任务的专门架构的方法不同,OneFormer 引入了一种新颖的多任务通用架构,该架构只需要一个模型,在通用数据集上训练一次,即可超越现有的框架。此外,在训练过程中引入查询-文本对比损失提高了模型学习任务间和类别间差异的能力。OneFormer 利用基于 Transformer 的架构,借鉴了近期计算机视觉领域的成功经验,并引入了任务引导查询以提高任务敏感性。结果令人印象深刻,因为 OneFormer 在 ADE20k、Cityscapes 和 COCO 等基准数据集上超越了语义分割、实例分割和全景分割任务中最先进的模型。该框架的性能通过新的 ConvNeXt 和 DiNAT 主干进一步增强。
总之,OneFormer 代表着朝着通用和易于访问的图像分割迈出的重要一步。通过引入能够处理各种分割任务的单一模型,该框架简化了分割过程,降低了资源需求。
模型使用示例
让我们看看模型的使用示例。Dinat 主干需要 Natten 库,安装可能需要一些时间。
!pip install -q natten
我们可以在下面看到根据不同分割类型的推断代码。
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
from PIL import Image
import requests
import matplotlib.pyplot as plt
def run_segmentation(image, task_type):
"""Performs image segmentation based on the given task type.
Args:
image (PIL.Image): The input image.
task_type (str): The type of segmentation to perform ('semantic', 'instance', or 'panoptic').
Returns:
PIL.Image: The segmented image.
Raises:
ValueError: If the task type is invalid.
"""
processor = OneFormerProcessor.from_pretrained(
"shi-labs/oneformer_ade20k_dinat_large"
) # Load once here
model = OneFormerForUniversalSegmentation.from_pretrained(
"shi-labs/oneformer_ade20k_dinat_large"
)
if task_type == "semantic":
inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt")
outputs = model(**inputs)
predicted_map = processor.post_process_semantic_segmentation(
outputs, target_sizes=[image.size[::-1]]
)[0]
elif task_type == "instance":
inputs = processor(images=image, task_inputs=["instance"], return_tensors="pt")
outputs = model(**inputs)
predicted_map = processor.post_process_instance_segmentation(
outputs, target_sizes=[image.size[::-1]]
)[0]["segmentation"]
elif task_type == "panoptic":
inputs = processor(images=image, task_inputs=["panoptic"], return_tensors="pt")
outputs = model(**inputs)
predicted_map = processor.post_process_panoptic_segmentation(
outputs, target_sizes=[image.size[::-1]]
)[0]["segmentation"]
else:
raise ValueError(
"Invalid task type. Choose from 'semantic', 'instance', or 'panoptic'"
)
return predicted_map
def show_image_comparison(image, predicted_map, segmentation_title):
"""Displays the original image and the segmented image side-by-side.
Args:
image (PIL.Image): The original image.
predicted_map (PIL.Image): The segmented image.
segmentation_title (str): The title for the segmented image.
"""
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title("Original Image")
plt.axis("off")
plt.subplot(1, 2, 2)
plt.imshow(predicted_map)
plt.title(segmentation_title + " Segmentation")
plt.axis("off")
plt.show()
url = "https://huggingface.co/datasets/shi-labs/oneformer_demo/resolve/main/ade20k.jpeg"
response = requests.get(url, stream=True)
response.raise_for_status() # Check for HTTP errors
image = Image.open(response.raw)
task_to_run = "semantic"
predicted_map = run_segmentation(image, task_to_run)
show_image_comparison(image, predicted_map, task_to_run)