掩码生成
掩码生成是为图像生成语义上有意义的掩码的任务。此任务与图像分割非常相似,但存在许多差异。图像分割模型在标记数据集上进行训练,并且仅限于它们在训练期间看到的类别;给定图像,它们会返回一组掩码和相应的类别。
掩码生成模型在大量数据上进行训练,并以两种模式运行。
- 提示模式:在此模式下,模型接收图像和提示,其中提示可以是图像中对象内的二维点位置(XY 坐标)或围绕对象的边界框。在提示模式下,模型仅返回指向提示的对象上的掩码。
- 分割所有内容模式:在分割所有内容模式中,给定图像,模型会生成图像中的每个掩码。为此,会生成一个点网格并覆盖在图像上以进行推理。
掩码生成任务由分割任何模型 (SAM)支持。它是一个功能强大的模型,由基于视觉转换器的图像编码器、提示编码器和双向转换器掩码解码器组成。图像和提示被编码,解码器接收这些嵌入并生成有效的掩码。
SAM 作为分割的强大基础模型,因为它具有较大的数据覆盖范围。它在SA-1B上进行训练,这是一个包含 100 万张图像和 11 亿个掩码的数据集。
在本指南中,您将学习如何
- 使用批处理在分割所有内容模式下进行推理,
- 在点提示模式下进行推理,
- 在框提示模式下进行推理。
首先,让我们安装 transformers
pip install -q transformers
掩码生成管道
推断掩码生成模型最简单的方法是使用 mask-generation
管道。
>>> from transformers import pipeline
>>> checkpoint = "facebook/sam-vit-base"
>>> mask_generator = pipeline(model=checkpoint, task="mask-generation")
让我们看看图片。
from PIL import Image
import requests
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
让我们分割所有内容。points-per-batch
允许在分割所有内容模式下并行推理点。这可以加快推理速度,但会消耗更多内存。此外,SAM 仅允许对点进行批处理,而不是对图像进行批处理。pred_iou_thresh
是 IoU 置信度阈值,只有高于该阈值的掩码才会返回。
masks = mask_generator(image, points_per_batch=128, pred_iou_thresh=0.88)
masks
如下所示
{'masks': [array([[False, False, False, ..., True, True, True],
[False, False, False, ..., True, True, True],
[False, False, False, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]]),
array([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
'scores': tensor([0.9972, 0.9917,
...,
}
我们可以这样可视化它们
import matplotlib.pyplot as plt
plt.imshow(image, cmap='gray')
for i, mask in enumerate(masks["masks"]):
plt.imshow(mask, cmap='viridis', alpha=0.1, vmin=0, vmax=1)
plt.axis('off')
plt.show()
以下是叠加了彩色地图的灰度原始图像。非常令人印象深刻。
模型推理
点提示
您也可以在没有管道的情况下使用模型。为此,请初始化模型和处理器。
from transformers import SamModel, SamProcessor
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
要进行点提示,请将输入点传递给处理器,然后获取处理器输出并将其传递给模型以进行推理。要对模型输出进行后处理,请传递输出以及我们从处理器的初始输出中获取的 original_sizes
和 reshaped_input_sizes
。我们需要传递这些,因为处理器会调整图像大小,并且需要外推输出。
input_points = [[[2592, 1728]]] # point location of the bee
inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
我们可以可视化 masks
输出中的三个掩码。
import matplotlib.pyplot as plt
import numpy as np
fig, axes = plt.subplots(1, 4, figsize=(15, 5))
axes[0].imshow(image)
axes[0].set_title('Original Image')
mask_list = [masks[0][0][0].numpy(), masks[0][0][1].numpy(), masks[0][0][2].numpy()]
for i, mask in enumerate(mask_list, start=1):
overlayed_image = np.array(image).copy()
overlayed_image[:,:,0] = np.where(mask == 1, 255, overlayed_image[:,:,0])
overlayed_image[:,:,1] = np.where(mask == 1, 0, overlayed_image[:,:,1])
overlayed_image[:,:,2] = np.where(mask == 1, 0, overlayed_image[:,:,2])
axes[i].imshow(overlayed_image)
axes[i].set_title(f'Mask {i}')
for ax in axes:
ax.axis('off')
plt.show()
框提示
您也可以以类似于点提示的方式进行框提示。您可以简单地以列表 [x_min, y_min, x_max, y_max]
格式以及图像一起将输入框传递给 processor
。获取处理器输出并将其直接传递给模型,然后再次对输出进行后处理。
# bounding box around the bee
box = [2350, 1600, 2850, 2100]
inputs = processor(
image,
input_boxes=[[[box]]],
return_tensors="pt"
).to("cuda")
with torch.no_grad():
outputs = model(**inputs)
mask = processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu()
)[0][0][0].numpy()
您可以可视化蜜蜂周围的边界框,如下所示。
import matplotlib.patches as patches
fig, ax = plt.subplots()
ax.imshow(image)
rectangle = patches.Rectangle((2350, 1600), 500, 500, linewidth=2, edgecolor='r', facecolor='none')
ax.add_patch(rectangle)
ax.axis("off")
plt.show()
您可以查看下面的推理输出。
fig, ax = plt.subplots()
ax.imshow(image)
ax.imshow(mask, cmap='viridis', alpha=0.4)
ax.axis("off")
plt.show()