Transformers 文档

掩码生成

Hugging Face's logo
加入 Hugging Face 社区

并获得增强文档体验

开始使用

掩码生成

掩码生成是为图像生成语义上有意义的掩码的任务。此任务与图像分割非常相似,但存在许多差异。图像分割模型在标记数据集上进行训练,并且仅限于它们在训练期间看到的类别;给定图像,它们会返回一组掩码和相应的类别。

掩码生成模型在大量数据上进行训练,并以两种模式运行。

  • 提示模式:在此模式下,模型接收图像和提示,其中提示可以是图像中对象内的二维点位置(XY 坐标)或围绕对象的边界框。在提示模式下,模型仅返回指向提示的对象上的掩码。
  • 分割所有内容模式:在分割所有内容模式中,给定图像,模型会生成图像中的每个掩码。为此,会生成一个点网格并覆盖在图像上以进行推理。

掩码生成任务由分割任何模型 (SAM)支持。它是一个功能强大的模型,由基于视觉转换器的图像编码器、提示编码器和双向转换器掩码解码器组成。图像和提示被编码,解码器接收这些嵌入并生成有效的掩码。

SAM Architecture

SAM 作为分割的强大基础模型,因为它具有较大的数据覆盖范围。它在SA-1B上进行训练,这是一个包含 100 万张图像和 11 亿个掩码的数据集。

在本指南中,您将学习如何

  • 使用批处理在分割所有内容模式下进行推理,
  • 在点提示模式下进行推理,
  • 在框提示模式下进行推理。

首先,让我们安装 transformers

pip install -q transformers

掩码生成管道

推断掩码生成模型最简单的方法是使用 mask-generation 管道。

>>> from transformers import pipeline

>>> checkpoint = "facebook/sam-vit-base"
>>> mask_generator = pipeline(model=checkpoint, task="mask-generation")

让我们看看图片。

from PIL import Image
import requests

img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
Example Image

让我们分割所有内容。points-per-batch 允许在分割所有内容模式下并行推理点。这可以加快推理速度,但会消耗更多内存。此外,SAM 仅允许对点进行批处理,而不是对图像进行批处理。pred_iou_thresh 是 IoU 置信度阈值,只有高于该阈值的掩码才会返回。

masks = mask_generator(image, points_per_batch=128, pred_iou_thresh=0.88)

masks 如下所示

{'masks': [array([[False, False, False, ...,  True,  True,  True],
         [False, False, False, ...,  True,  True,  True],
         [False, False, False, ...,  True,  True,  True],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False]]),
  array([[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
'scores': tensor([0.9972, 0.9917,
        ...,
}

我们可以这样可视化它们

import matplotlib.pyplot as plt

plt.imshow(image, cmap='gray')

for i, mask in enumerate(masks["masks"]):
    plt.imshow(mask, cmap='viridis', alpha=0.1, vmin=0, vmax=1)

plt.axis('off')
plt.show()

以下是叠加了彩色地图的灰度原始图像。非常令人印象深刻。

Visualized

模型推理

点提示

您也可以在没有管道的情况下使用模型。为此,请初始化模型和处理器。

from transformers import SamModel, SamProcessor
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

要进行点提示,请将输入点传递给处理器,然后获取处理器输出并将其传递给模型以进行推理。要对模型输出进行后处理,请传递输出以及我们从处理器的初始输出中获取的 original_sizesreshaped_input_sizes。我们需要传递这些,因为处理器会调整图像大小,并且需要外推输出。

input_points = [[[2592, 1728]]] # point location of the bee

inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())

我们可以可视化 masks 输出中的三个掩码。

import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(1, 4, figsize=(15, 5))

axes[0].imshow(image)
axes[0].set_title('Original Image')
mask_list = [masks[0][0][0].numpy(), masks[0][0][1].numpy(), masks[0][0][2].numpy()]

for i, mask in enumerate(mask_list, start=1):
    overlayed_image = np.array(image).copy()

    overlayed_image[:,:,0] = np.where(mask == 1, 255, overlayed_image[:,:,0])
    overlayed_image[:,:,1] = np.where(mask == 1, 0, overlayed_image[:,:,1])
    overlayed_image[:,:,2] = np.where(mask == 1, 0, overlayed_image[:,:,2])
    
    axes[i].imshow(overlayed_image)
    axes[i].set_title(f'Mask {i}')
for ax in axes:
    ax.axis('off')

plt.show()
Visualized

框提示

您也可以以类似于点提示的方式进行框提示。您可以简单地以列表 [x_min, y_min, x_max, y_max] 格式以及图像一起将输入框传递给 processor。获取处理器输出并将其直接传递给模型,然后再次对输出进行后处理。

# bounding box around the bee
box = [2350, 1600, 2850, 2100]

inputs = processor(
        image,
        input_boxes=[[[box]]],
        return_tensors="pt"
    ).to("cuda")

with torch.no_grad():
    outputs = model(**inputs)

mask = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu()
)[0][0][0].numpy()

您可以可视化蜜蜂周围的边界框,如下所示。

import matplotlib.patches as patches

fig, ax = plt.subplots()
ax.imshow(image)

rectangle = patches.Rectangle((2350, 1600), 500, 500, linewidth=2, edgecolor='r', facecolor='none')
ax.add_patch(rectangle)
ax.axis("off")
plt.show()
Visualized Bbox

您可以查看下面的推理输出。

fig, ax = plt.subplots()
ax.imshow(image)
ax.imshow(mask, cmap='viridis', alpha=0.4)

ax.axis("off")
plt.show()
Visualized Inference
< > 在 GitHub 上更新