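Keypoint detection
Keypoint detection identifies and locates specific points of interest within an image. These keypoints, also known as landmarks, represent meaningful features of objects, such as facial features or object parts. These models take an image as input and return the following outputs:
- **Keypoints and scores**: the points of interest and their confidence scores.
- **Descriptors**: a representation of the image region surrounding each keypoint, capturing its texture, gradient, orientation, and other properties.
In this guide, we will show how to extract keypoints from images.
We will use SuperPoint, a foundation model for keypoint detection.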
from transformers import AutoImageProcessor, SuperPointForKeypointDetection
processor = AutoImageProcessor.from_pretrained("magic-leap-community/superpoint")
model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/superpoint")
Let's test the model on the images below.
import torch
from PIL import Image
import requests
url_image_1 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image_1 = Image.open(requests.get(url_image_1, stream=True).raw)
url_image_2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png"
image_2 = Image.open(requests.get(url_image_2, stream=True).raw)
images = [image_1, image_2]
We can now process the inputs and run inference.
inputs = processor(images, return_tensors="pt").to(model.device, model.dtype)
outputs = model(**inputs)
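The model output contains relative keypoints, descriptors, masks, and scores for each item in the batch. Because images can yield different numbers of keypoints, the batch is padded; the mask marks which entries are actual keypoints (1) and which are padding (0).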
SuperPointKeypointDescriptionOutput(loss=None, keypoints=tensor([[[0.0437, 0.0167],
[0.0688, 0.0167],
[0.0172, 0.0188],
...,
[0.5984, 0.9812],
[0.6953, 0.9812]]]),
scores=tensor([[0.0056, 0.0053, 0.0079, ..., 0.0125, 0.0539, 0.0377],
[0.0206, 0.0058, 0.0065, ..., 0.0000, 0.0000, 0.0000]],
grad_fn=<CopySlices>), descriptors=tensor([[[-0.0807, 0.0114, -0.1210, ..., -0.1122, 0.0899, 0.0357],
[-0.0807, 0.0114, -0.1210, ..., -0.1122, 0.0899, 0.0357],
[-0.0807, 0.0114, -0.1210, ..., -0.1122, 0.0899, 0.0357],
...],
grad_fn=<CopySlices>), mask=tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 0, 0, 0]], dtype=torch.int32), hidden_states=None)
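As a rough illustration of what the padding mask and the relative coordinates mean, the sketch below drops padded rows and scales relative (x, y) values to pixel coordinates. It uses NumPy with made-up toy values rather than real model output. `to_pixel_keypoints` is a hypothetical helper, and truncating to integers is an assumption about how pixel positions are derived, not a statement of the library's implementation.

```python
import numpy as np


def to_pixel_keypoints(rel_keypoints, mask, image_size):
    """Hypothetical helper: drop padded rows and scale relative (x, y) to pixels.

    image_size is given as (height, width).
    """
    height, width = image_size
    valid = mask.astype(bool)  # 1 = real keypoint, 0 = padding
    scale = np.array([width, height], dtype=np.float64)  # x scales by width, y by height
    return (rel_keypoints[valid] * scale).astype(np.int32)


# Toy batch item with three slots; the last slot is padding (mask == 0).
rel_keypoints = np.array([[0.25, 0.5], [0.75, 0.125], [0.0, 0.0]])
mask = np.array([1, 1, 0])

pixel_keypoints = to_pixel_keypoints(rel_keypoints, mask, image_size=(200, 400))
print(pixel_keypoints)
```

In practice the processor handles this step for you, so a helper like this is only useful for understanding the raw output.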
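To plot the actual keypoints on the image, we need to post-process the output. To do so, we pass the original image sizes, in (height, width) order, along with the outputs to `post_process_keypoint_detection`.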
image_sizes = [(image.size[1], image.size[0]) for image in images]
outputs = processor.post_process_keypoint_detection(outputs, image_sizes)
The outputs are now a list of dictionaries, one per image, each holding the processed keypoints, scores, and descriptors.
[{'keypoints': tensor([[ 226, 57],
[ 356, 57],
[ 89, 64],
...,
[3604, 3391]], dtype=torch.int32),
'scores': tensor([0.0056, 0.0053, ...], grad_fn=<IndexBackward0>),
'descriptors': tensor([[-0.0807, 0.0114, -0.1210, ..., -0.1122, 0.0899, 0.0357],
[-0.0807, 0.0114, -0.1210, ..., -0.1122, 0.0899, 0.0357]],
grad_fn=<IndexBackward0>)},
{'keypoints': tensor([[ 46, 6],
[ 78, 6],
[422, 6],
[206, 404]], dtype=torch.int32),
'scores': tensor([0.0206, 0.0058, 0.0065, 0.0053, 0.0070, ...,grad_fn=<IndexBackward0>),
'descriptors': tensor([[-0.0525, 0.0726, 0.0270, ..., 0.0389, -0.0189, -0.0211],
[-0.0525, 0.0726, 0.0270, ..., 0.0389, -0.0189, -0.0211]}]
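Descriptors are typically compared across two views of the same scene to find corresponding keypoints. The sketch below shows one common approach, mutual nearest-neighbour matching on cosine similarity, using NumPy and toy vectors. `mutual_nearest_matches` is a hypothetical helper and not part of the library, and the two example images in this guide show unrelated scenes, so matching them would not be meaningful.

```python
import numpy as np


def mutual_nearest_matches(desc_a, desc_b):
    """Hypothetical helper: return (i, j) pairs where desc_a[i] and desc_b[j]
    are each other's most similar descriptor."""
    # L2-normalise the rows so the dot product equals cosine similarity.
    a = desc_a / np.linalg.norm(desc_a, axis=1, keepdims=True)
    b = desc_b / np.linalg.norm(desc_b, axis=1, keepdims=True)
    similarity = a @ b.T  # shape: (len(desc_a), len(desc_b))

    best_b_for_a = similarity.argmax(axis=1)
    best_a_for_b = similarity.argmax(axis=0)
    return [
        (i, int(j))
        for i, j in enumerate(best_b_for_a)
        if best_a_for_b[j] == i
    ]


# Toy descriptors (3-dimensional for readability; real descriptors are much longer).
desc_a = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
desc_b = np.array([[0.0, 0.9, 0.1], [0.9, 0.1, 0.0], [0.0, 0.0, 1.0]])

matches = mutual_nearest_matches(desc_a, desc_b)
print(matches)
```

With real outputs, the inputs would be something like `outputs[0]["descriptors"].detach().numpy()` for each of two overlapping images.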
We can use these to plot the keypoints.
import matplotlib.pyplot as plt

for i in range(len(images)):
    # Detach from the autograd graph before converting to NumPy
    keypoints = outputs[i]["keypoints"].detach().numpy()
    scores = outputs[i]["scores"].detach().numpy()
    image = images[i]

    plt.axis("off")
    plt.imshow(image)
    # Marker size scales with each keypoint's confidence score
    plt.scatter(
        keypoints[:, 0],
        keypoints[:, 1],
        s=scores * 100,
        c="cyan",
        alpha=0.4,
    )
    plt.show()
You can see the outputs below.
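If a plot looks cluttered, one option is to keep only the highest-scoring keypoints before drawing them. The sketch below uses NumPy with made-up toy values; `top_k_keypoints` is a hypothetical helper, and the value of `k` is an arbitrary choice.

```python
import numpy as np


def top_k_keypoints(keypoints, scores, k):
    """Hypothetical helper: keep the k keypoints with the highest scores."""
    order = np.argsort(scores)[::-1][:k]  # indices sorted by descending score
    return keypoints[order], scores[order]


# Toy values standing in for one image's post-processed output.
keypoints = np.array([[226, 57], [356, 57], [89, 64], [40, 12]])
scores = np.array([0.0056, 0.0539, 0.0079, 0.0377])

top_points, top_scores = top_k_keypoints(keypoints, scores, k=2)
print(top_points, top_scores)
```

The filtered arrays can be passed to `plt.scatter` in the same way as the full set.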