Transformers 文档
掩码生成
并获得增强的文档体验
开始使用
掩码生成
掩码生成是为图像生成语义上有意义的掩码的任务。此任务与图像分割非常相似,但存在许多差异。图像分割模型在标记数据集上进行训练,并且仅限于他们在训练期间看到的类别;给定图像,它们返回一组掩码和相应的类别。
掩码生成模型在大量数据上进行训练,并在两种模式下运行。
- 提示模式:在此模式下,模型接收图像和提示,其中提示可以是对象内图像中的 2D 点位置(XY 坐标)或对象周围的边界框。在提示模式下,模型仅返回提示所指向对象上的掩码。
- 分割所有内容模式:在分割所有内容模式下,给定图像,模型会生成图像中的每个掩码。为此,会生成一个点网格并覆盖在图像上以进行推理。
掩码生成任务由Segment Anything Model (SAM)支持。它是一个强大的模型,由基于 Vision Transformer 的图像编码器、提示编码器和双向 Transformer 掩码解码器组成。图像和提示被编码,解码器获取这些嵌入并生成有效的掩码。

SAM 作为分割的强大基础模型,因为它具有广泛的数据覆盖范围。它在 SA-1B 上进行训练,这是一个包含 100 万张图像和 11 亿个掩码的数据集。
在本指南中,您将学习如何
- 在分割所有内容模式下进行批量推理,
- 在点提示模式下进行推理,
- 在框提示模式下进行推理。
首先,让我们安装 transformers
pip install -q transformers
掩码生成 Pipeline
推断掩码生成模型的最简单方法是使用 mask-generation
pipeline。
>>> from transformers import pipeline
>>> checkpoint = "facebook/sam-vit-base"
>>> mask_generator = pipeline(model=checkpoint, task="mask-generation")
让我们看看图像。
from PIL import Image
import requests
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

让我们分割所有内容。points-per-batch
允许在分割所有内容模式下并行推断点。这可以实现更快的推理,但会消耗更多内存。此外,SAM 仅允许对点而不是图像进行批量处理。pred_iou_thresh
是 IoU 置信度阈值,仅返回高于该特定阈值的掩码。
masks = mask_generator(image, points_per_batch=128, pred_iou_thresh=0.88)
masks
看起来像这样
{'masks': [array([[False, False, False, ..., True, True, True],
[False, False, False, ..., True, True, True],
[False, False, False, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]]),
array([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
'scores': tensor([0.9972, 0.9917,
...,
}
我们可以这样可视化它们
import matplotlib.pyplot as plt
plt.imshow(image, cmap='gray')
for i, mask in enumerate(masks["masks"]):
plt.imshow(mask, cmap='viridis', alpha=0.1, vmin=0, vmax=1)
plt.axis('off')
plt.show()
下面是以灰度显示的原始图像,并覆盖了彩色地图。非常 впечатляет。

模型推理
点提示
您也可以在不使用 pipeline 的情况下使用模型。为此,请初始化模型和处理器。
from transformers import SamModel, SamProcessor
import torch
from accelerate.test_utils.testing import get_backend
# automatically detects the underlying device type (CUDA, CPU, XPU, MPS, etc.)
device, _, _ = get_backend()
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
要进行点提示,请将输入点传递给处理器,然后获取处理器输出并将其传递给模型以进行推理。要对模型输出进行后处理,请传递我们从处理器的初始输出中获取的输出以及 original_sizes
和 reshaped_input_sizes
。我们需要传递这些,因为处理器会调整图像大小,并且输出需要外推。
input_points = [[[2592, 1728]]] # point location of the bee
inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
我们可以可视化 masks
输出中的三个掩码。
import matplotlib.pyplot as plt
import numpy as np
fig, axes = plt.subplots(1, 4, figsize=(15, 5))
axes[0].imshow(image)
axes[0].set_title('Original Image')
mask_list = [masks[0][0][0].numpy(), masks[0][0][1].numpy(), masks[0][0][2].numpy()]
for i, mask in enumerate(mask_list, start=1):
overlayed_image = np.array(image).copy()
overlayed_image[:,:,0] = np.where(mask == 1, 255, overlayed_image[:,:,0])
overlayed_image[:,:,1] = np.where(mask == 1, 0, overlayed_image[:,:,1])
overlayed_image[:,:,2] = np.where(mask == 1, 0, overlayed_image[:,:,2])
axes[i].imshow(overlayed_image)
axes[i].set_title(f'Mask {i}')
for ax in axes:
ax.axis('off')
plt.show()

框提示
您也可以使用与点提示类似的方式进行框提示。您只需将列表格式 [x_min, y_min, x_max, y_max]
的输入框与图像一起传递给 processor
。获取处理器输出并直接将其传递给模型,然后再次对输出进行后处理。
# bounding box around the bee
box = [2350, 1600, 2850, 2100]
inputs = processor(
image,
input_boxes=[[[box]]],
return_tensors="pt"
).to("cuda")
with torch.no_grad():
outputs = model(**inputs)
mask = processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu()
)[0][0][0].numpy()
您可以可视化蜜蜂周围的边界框,如下所示。
import matplotlib.patches as patches
fig, ax = plt.subplots()
ax.imshow(image)
rectangle = patches.Rectangle((2350, 1600), 500, 500, linewidth=2, edgecolor='r', facecolor='none')
ax.add_patch(rectangle)
ax.axis("off")
plt.show()

您可以在下面看到推理输出。
fig, ax = plt.subplots()
ax.imshow(image)
ax.imshow(mask, cmap='viridis', alpha=0.4)
ax.axis("off")
plt.show()
