社区计算机视觉课程文档
OneFormer:一个 Transformer 统一通用图像分割
并获得增强的文档体验
开始使用
OneFormer:一个 Transformer 统一通用图像分割
简介
OneFormer 是图像分割领域的一项突破性方法,图像分割是一项涉及将图像分割成有意义的片段的计算机视觉任务。传统方法对不同的分割任务使用独立的模型和架构,例如识别对象(实例分割)或标记区域(语义分割)。最近的尝试旨在通过共享架构统一这些任务,但仍然需要为每个任务进行单独的训练。
OneFormer 是一款旨在克服这些挑战的通用图像分割框架。它引入了一种独特的多任务方法,允许单个模型处理语义、实例和全景分割任务,而无需对每个任务进行单独训练。其关键创新在于任务条件联合训练策略,模型通过任务输入进行引导,使其在训练和推理过程中都具有动态性和适应性。
这一突破不仅简化了训练过程,还在各种数据集上超越了现有模型。OneFormer 通过使用全景标注实现了这一点,统一了所有任务所需的真实信息。此外,该框架还引入了查询-文本对比学习,以更好地区分任务并提高整体性能。
OneFormer 背景
图片来自 OneFormer 论文
为了理解 OneFormer 的重要性,让我们深入了解图像分割的背景。在图像处理中,分割涉及将图像划分为不同的部分,这对于识别对象和理解场景内容等任务至关重要。传统上,两种主要类型的分割任务是语义分割(其中像素被标记为“道路”或“天空”等类别)和实例分割(识别具有明确边界的对象)。
随着时间的推移,研究人员提出了全景分割作为统一语义分割和实例分割任务的方法。然而,即使取得了这些进展,仍然存在挑战。为全景分割设计的现有模型仍然需要为每个任务进行单独训练,充其量也只能算是半通用。
这就是 OneFormer 作为一个颠覆者的出现。它引入了一种新颖的方法——多任务通用架构。其理念是只训练一次这个框架,使用一个单一的通用架构、一个独立模型和一个数据集。OneFormer 的神奇之处在于,它在语义、实例和全景分割任务上都超越了专用框架。这一突破不仅仅是为了提高准确性;它是为了使图像分割更通用和高效。有了 OneFormer,对不同任务的大量资源和单独训练的需求成为过去。
OneFormer 的核心概念
现在,让我们分解一下 OneFormer 的主要特点,让它脱颖而出。
任务动态掩码
OneFormer 使用了一种名为“任务动态掩码”的巧妙技巧,以更好地理解和处理不同类型的图像分割任务。因此,当模型遇到图像时,它会使用此“任务动态掩码”来决定是关注整体场景,识别具有清晰边界的特定对象,还是两者兼而有之。
任务条件联合训练
OneFormer 的开创性功能之一是其任务条件联合训练策略。OneFormer 不再为语义、实例和全景分割进行单独训练,而是在训练期间统一采样任务。这种策略使模型能够同时学习和泛化不同的分割任务。通过任务令牌对特定任务进行架构条件设置,OneFormer 统一了训练过程,减少了对任务特定架构、模型和数据集的需求。这种创新方法显着简化了训练流程和资源要求。
查询-文本对比损失
最后,我们来谈谈“查询-文本对比损失”。可以将其视为 OneFormer 教导自己任务和类别之间差异的一种方式。在训练过程中,模型将其从图像中提取的特征(查询)与相应的文本描述(例如“一辆汽车的照片”)进行比较。这有助于模型理解每个任务的独特特征,并减少不同类别之间的混淆。OneFormer 的“任务动态掩码”使其能够像多任务助手一样灵活,而“查询-文本对比损失”通过比较视觉特征和文本描述来帮助它学习每个任务的具体细节。
通过结合这些核心概念,OneFormer 成为一种智能高效的图像分割工具,使过程更通用和易于访问。
结论
图片来自 OneFormer 论文
总之,OneFormer 框架代表了图像分割领域的一项突破性方法,旨在简化和统一各个领域的任务。与依赖于每个分割任务的专用架构的传统方法不同,OneFormer 引入了一种新颖的多任务通用架构,只需一个模型,在通用数据集上训练一次,即可超越现有框架。此外,在训练过程中引入查询-文本对比损失增强了模型学习任务间和类间差异的能力。OneFormer 利用基于 Transformer 的架构,受计算机视觉领域最近成功的启发,并引入了任务引导查询以提高任务敏感性。结果令人印象深刻,OneFormer 在 ADE20k、Cityscapes 和 COCO 等基准数据集上的语义、实例和全景分割任务中超越了最先进的模型。该框架的性能通过新的 ConvNeXt 和 DiNAT 主干进一步增强。
总而言之,OneFormer 代表了通用且易于访问的图像分割迈出了重要一步。通过引入一个能够处理各种分割任务的单一模型,该框架简化了分割过程并减少了资源需求。
模型使用示例
让我们看一个模型的使用示例。Dinat 主干需要 Natten 库,安装可能需要一段时间。
!pip install -q natten
下面我们可以看到根据不同分割类型进行的推理代码。
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
from PIL import Image
import requests
import matplotlib.pyplot as plt
def run_segmentation(image, task_type):
"""Performs image segmentation based on the given task type.
Args:
image (PIL.Image): The input image.
task_type (str): The type of segmentation to perform ('semantic', 'instance', or 'panoptic').
Returns:
PIL.Image: The segmented image.
Raises:
ValueError: If the task type is invalid.
"""
processor = OneFormerProcessor.from_pretrained(
"shi-labs/oneformer_ade20k_dinat_large"
) # Load once here
model = OneFormerForUniversalSegmentation.from_pretrained(
"shi-labs/oneformer_ade20k_dinat_large"
)
if task_type == "semantic":
inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt")
outputs = model(**inputs)
predicted_map = processor.post_process_semantic_segmentation(
outputs, target_sizes=[image.size[::-1]]
)[0]
elif task_type == "instance":
inputs = processor(images=image, task_inputs=["instance"], return_tensors="pt")
outputs = model(**inputs)
predicted_map = processor.post_process_instance_segmentation(
outputs, target_sizes=[image.size[::-1]]
)[0]["segmentation"]
elif task_type == "panoptic":
inputs = processor(images=image, task_inputs=["panoptic"], return_tensors="pt")
outputs = model(**inputs)
predicted_map = processor.post_process_panoptic_segmentation(
outputs, target_sizes=[image.size[::-1]]
)[0]["segmentation"]
else:
raise ValueError(
"Invalid task type. Choose from 'semantic', 'instance', or 'panoptic'"
)
return predicted_map
def show_image_comparison(image, predicted_map, segmentation_title):
"""Displays the original image and the segmented image side-by-side.
Args:
image (PIL.Image): The original image.
predicted_map (PIL.Image): The segmented image.
segmentation_title (str): The title for the segmented image.
"""
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title("Original Image")
plt.axis("off")
plt.subplot(1, 2, 2)
plt.imshow(predicted_map)
plt.title(segmentation_title + " Segmentation")
plt.axis("off")
plt.show()
url = "https://huggingface.co/datasets/shi-labs/oneformer_demo/resolve/main/ade20k.jpeg"
response = requests.get(url, stream=True)
response.raise_for_status() # Check for HTTP errors
image = Image.open(response.raw)
task_to_run = "semantic"
predicted_map = run_segmentation(image, task_to_run)
show_image_comparison(image, predicted_map, task_to_run)