社区计算机视觉课程文档
用于物体检测的 Vision Transformer
并获得增强的文档体验
开始使用
用于物体检测的 Vision Transformer
本节将介绍如何使用 Vision Transformer 实现物体检测任务。我们将了解如何为我们的用例微调现有的预训练物体检测模型。在开始之前,请查看这个 HuggingFace Space,您可以在其中试用最终输出。
简介
物体检测是一项计算机视觉任务,涉及识别和定位图像或视频中的物体。它包含两个主要步骤
- 首先,识别存在的物体类型(例如汽车、人或动物)。
- 其次,通过在物体周围绘制边界框来确定它们的精确位置。
这些模型通常接收图像(静态图像或视频帧)作为输入,每个图像中存在多个物体。例如,考虑一个包含多个物体的图像,例如汽车、人、自行车等等。在处理输入后,这些模型会生成一组数字,传达以下信息
- 物体的位置(边界框的 XY 坐标)。
- 物体的类别。
物体检测有很多应用。最重要的例子之一是在自动驾驶领域,其中物体检测用于检测汽车周围的不同物体(如行人、道路标志、交通信号灯等),这些物体成为做出决策的输入之一。
为了加深您对物体检测来龙去脉的理解,请查看我们关于物体检测 🤗 的专门章节。
物体检测中微调模型的必要性 🤔
您应该构建新模型,还是修改现有模型? 这是一个很棒的问题。 从头开始训练物体检测模型意味着
- 一遍又一遍地重复已经完成的研究。
- 编写重复的模型代码、训练它们,并为不同的用例维护不同的存储库。
- 大量的实验和资源浪费。
与其做这一切,不如采用性能良好的预训练模型(一种在识别通用特征方面做得非常出色的模型),并调整或重新调整其权重(或其权重的某些部分)以使其适应您的用例。我们相信或假设预训练模型已经学到了足够的知识来提取图像内部的重要特征,从而定位和分类物体。因此,如果引入新物体,则可以对同一模型进行短时间的训练和计算,以开始借助已学习和新特征来检测这些新物体。
在本教程结束时,您应该能够为物体检测用例创建一个完整的流水线(从加载数据集、微调模型到执行推理)。
安装必要的库
让我们从安装开始。 只需执行以下单元格即可安装必要的软件包。 在本教程中,我们将使用 Hugging Face Transformers 和 PyTorch。
!pip install -U -q datasets transformers[torch] evaluate timm albumentations accelerate
场景
为了使本教程更有趣,让我们考虑一个真实的例子。考虑以下场景:建筑工人在建筑区域工作时需要极高的安全性。基本安全协议要求每次都佩戴头盔。由于建筑工人很多,因此很难时刻关注每个人。
但是,如果我们有一个摄像头系统,可以实时检测人员以及人员是否佩戴头盔,那不是很棒吗?
因此,我们将微调一个轻量级物体检测模型来做到这一点。让我们深入了解一下。
数据集
对于上述场景,我们将使用hardhat数据集,该数据集由中国东北大学提供。我们可以使用 🤗 datasets
下载并加载此数据集。
from datasets import load_dataset
dataset = load_dataset("anindya64/hardhat")
dataset
这将为您提供以下数据结构
DatasetDict({
train: Dataset({
features: ['image', 'image_id', 'width', 'height', 'objects'],
num_rows: 5297
})
test: Dataset({
features: ['image', 'image_id', 'width', 'height', 'objects'],
num_rows: 1766
})
})
上面是一个 DatasetDict,它是一个高效的类字典结构,包含训练集和测试集中的整个数据集。 如您所见,在每个拆分(训练和测试)下,我们都有 features
和 num_rows
。 在特征下,我们有 image
,一个 Pillow 对象,图像的 id、高度和宽度以及物体。 现在让我们看看每个数据点(在训练/测试集中)是什么样的。 为此,请运行以下行
dataset["train"][0]
这将为您提供以下结构
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x375>,
'image_id': 1,
'width': 500,
'height': 375,
'objects': {'id': [1, 1],
'area': [3068.0, 690.0],
'bbox': [[178.0, 84.0, 52.0, 59.0], [111.0, 144.0, 23.0, 30.0]],
'category': ['helmet', 'helmet']}}
如您所见,objects
是另一个字典,其中包含物体 id(此处为类别 id)、物体的面积以及边界框坐标 (bbox
) 和类别(或标签)。 以下是数据元素的每个键和值的更详细说明。
image
:这是一个 Pillow Image 对象,有助于在甚至从路径加载之前直接查看图像。image_id
:表示图像编号来自训练文件。width
:图像的宽度。height
:图像的高度。objects
:另一个包含有关注释信息的字典。 这包含以下内容id
:一个列表,列表的长度表示物体的数量,每个值表示类别索引。area
:物体的面积。bbox
:表示物体的边界框坐标。category
:物体的类别(字符串)。
现在让我们正确提取训练样本和测试样本。 对于本教程,我们大约有 5000 个训练样本和 1700 个测试样本。
# First, extract out the train and test set
train_dataset = dataset["train"]
test_dataset = dataset["test"]
现在我们知道了样本数据点包含什么,让我们首先绘制该样本。 在这里,我们将首先绘制图像,然后再绘制相应的边界框。
以下是我们将要做的
- 获取图像及其对应的高度和宽度。
- 创建一个绘图对象,可以轻松地在图像上绘制文本和线条。
- 从样本中获取注释字典。
- 遍历它。
- 对于每个,获取边界框坐标,即 x(边界框水平开始的位置)、y(边界框垂直开始的位置)、w(边界框的宽度)、h(边界框的高度)。
- 现在,如果边界框尺寸已标准化,则缩放它,否则保持原样。
- 最后绘制矩形和类别文本。
import numpy as np
from PIL import Image, ImageDraw
def draw_image_from_idx(dataset, idx):
sample = dataset[idx]
image = sample["image"]
annotations = sample["objects"]
draw = ImageDraw.Draw(image)
width, height = sample["width"], sample["height"]
for i in range(len(annotations["id"])):
box = annotations["bbox"][i]
class_idx = annotations["id"][i]
x, y, w, h = tuple(box)
if max(box) > 1.0:
x1, y1 = int(x), int(y)
x2, y2 = int(x + w), int(y + h)
else:
x1 = int(x * width)
y1 = int(y * height)
x2 = int((x + w) * width)
y2 = int((y + h) * height)
draw.rectangle((x1, y1, x2, y2), outline="red", width=1)
draw.text((x1, y1), annotations["category"][i], fill="white")
return image
draw_image_from_idx(dataset=train_dataset, idx=10)
我们有一个绘制单个图像的函数,让我们使用上面的函数编写一个简单的函数来绘制多个图像。 这将帮助我们进行一些分析。
import matplotlib.pyplot as plt
def plot_images(dataset, indices):
"""
Plot images and their annotations.
"""
num_rows = len(indices) // 3
num_cols = 3
fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 10))
for i, idx in enumerate(indices):
row = i // num_cols
col = i % num_cols
# Draw image
image = draw_image_from_idx(dataset, idx)
# Display image on the corresponding subplot
axes[row, col].imshow(image)
axes[row, col].axis("off")
plt.tight_layout()
plt.show()
# Now use the function to plot images
plot_images(train_dataset, range(9))
运行该函数将为我们提供如下图所示的精美拼贴画。
AutoImageProcessor
在微调模型之前,我们必须对数据进行预处理,使其与预训练时使用的方法完全匹配。 HuggingFace AutoImageProcessor 负责处理图像数据以创建 pixel_values
、pixel_mask
和 labels
,DETR 模型可以使用它们进行训练。
现在,让我们从我们想要使用模型微调的同一检查点实例化图像处理器。
from transformers import AutoImageProcessor
checkpoint = "facebook/detr-resnet-50-dc5"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
预处理数据集
在将图像传递给 image_processor
之前,我们还对图像及其对应的边界框应用不同类型的增强。
简单来说,增强是一组随机变换,例如旋转、调整大小等。应用这些是为了获得更多样本,并使视觉模型对图像的不同条件更加鲁棒。我们将使用 albumentations 库来实现此目的。它允许您创建图像的随机变换,以便增加训练的样本量。
import albumentations
import numpy as np
import torch
transform = albumentations.Compose(
[
albumentations.Resize(480, 480),
albumentations.HorizontalFlip(p=1.0),
albumentations.RandomBrightnessContrast(p=1.0),
],
bbox_params=albumentations.BboxParams(format="coco", label_fields=["category"]),
)
一旦我们初始化所有变换,我们需要创建一个函数来格式化注释,并返回具有非常特定格式的注释列表。
这是因为 image_processor
期望注释采用以下格式:{'image_id': int, 'annotations': List[Dict]}
,其中每个字典都是 COCO 物体注释。
def formatted_anns(image_id, category, area, bbox):
annotations = []
for i in range(0, len(category)):
new_ann = {
"image_id": image_id,
"category_id": category[i],
"isCrowd": 0,
"area": area[i],
"bbox": list(bbox[i]),
}
annotations.append(new_ann)
return annotations
最后,我们将图像和注释变换结合起来,对整个数据集批次进行变换。
以下是执行此操作的最终代码
# transforming a batch
def transform_aug_ann(examples):
image_ids = examples["image_id"]
images, bboxes, area, categories = [], [], [], []
for image, objects in zip(examples["image"], examples["objects"]):
image = np.array(image.convert("RGB"))[:, :, ::-1]
out = transform(image=image, bboxes=objects["bbox"], category=objects["id"])
area.append(objects["area"])
images.append(out["image"])
bboxes.append(out["bboxes"])
categories.append(out["category"])
targets = [
{"image_id": id_, "annotations": formatted_anns(id_, cat_, ar_, box_)}
for id_, cat_, ar_, box_ in zip(image_ids, categories, area, bboxes)
]
return image_processor(images=images, annotations=targets, return_tensors="pt")
最后,您只需将此预处理函数应用于整个数据集即可。 您可以使用 HuggingFace 🤗 Datasets with transform 方法来实现此目的。
# Apply transformations for both train and test dataset
train_dataset_transformed = train_dataset.with_transform(transform_aug_ann)
test_dataset_transformed = test_dataset.with_transform(transform_aug_ann)
现在让我们看看转换后的训练数据集样本是什么样的
train_dataset_transformed[0]
这将返回一个张量字典。 我们这里主要需要的是代表图像的 pixel_values
、注意力掩码 pixel_mask
和 labels
。 这是一个数据点的样子
{'pixel_values': tensor([[[-0.1657, -0.1657, -0.1657, ..., -0.3369, -0.4739, -0.5767],
[-0.1657, -0.1657, -0.1657, ..., -0.3369, -0.4739, -0.5767],
[-0.1657, -0.1657, -0.1828, ..., -0.3541, -0.4911, -0.5938],
...,
[-0.4911, -0.5596, -0.6623, ..., -0.7137, -0.7650, -0.7993],
[-0.4911, -0.5596, -0.6794, ..., -0.7308, -0.7993, -0.8335],
[-0.4911, -0.5596, -0.6794, ..., -0.7479, -0.8164, -0.8507]],
[[-0.0924, -0.0924, -0.0924, ..., 0.0651, -0.0749, -0.1800],
[-0.0924, -0.0924, -0.0924, ..., 0.0651, -0.0924, -0.2150],
[-0.0924, -0.0924, -0.1099, ..., 0.0476, -0.1275, -0.2500],
...,
[-0.0924, -0.1800, -0.3200, ..., -0.4426, -0.4951, -0.5301],
[-0.0924, -0.1800, -0.3200, ..., -0.4601, -0.5126, -0.5651],
[-0.0924, -0.1800, -0.3200, ..., -0.4601, -0.5301, -0.5826]],
[[ 0.1999, 0.1999, 0.1999, ..., 0.6705, 0.5136, 0.4091],
[ 0.1999, 0.1999, 0.1999, ..., 0.6531, 0.4962, 0.3916],
[ 0.1999, 0.1999, 0.1825, ..., 0.6356, 0.4614, 0.3568],
...,
[ 0.4788, 0.3916, 0.2696, ..., 0.1825, 0.1302, 0.0953],
[ 0.4788, 0.3916, 0.2696, ..., 0.1651, 0.0953, 0.0605],
[ 0.4788, 0.3916, 0.2696, ..., 0.1476, 0.0779, 0.0431]]]),
'pixel_mask': tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
...,
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]]),
'labels': {'size': tensor([800, 800]), 'image_id': tensor([1]), 'class_labels': tensor([1, 1]), 'boxes': tensor([[0.5920, 0.3027, 0.1040, 0.1573],
[0.7550, 0.4240, 0.0460, 0.0800]]), 'area': tensor([8522.2217, 1916.6666]), 'iscrowd': tensor([0, 0]), 'orig_size': tensor([480, 480])}}
我们快到了 🚀。 作为最后一个预处理步骤,我们需要编写一个自定义 collate_fn
。 现在什么是 collate_fn
?
collate_fn
负责从数据集中获取样本列表,并将它们转换为适合模型输入格式的批次。
一般来说,DataCollator
通常执行诸如填充、截断等任务。 在自定义整理函数中,我们经常定义我们想要如何以及如何将数据分组到批次中,或者简单地说,如何表示每个批次。
数据整理器主要将数据放在一起,然后对其进行预处理。 让我们制作我们的整理函数。
def collate_fn(batch):
pixel_values = [item["pixel_values"] for item in batch]
encoding = image_processor.pad(pixel_values, return_tensors="pt")
labels = [item["labels"] for item in batch]
batch = {}
batch["pixel_values"] = encoding["pixel_values"]
batch["pixel_mask"] = encoding["pixel_mask"]
batch["labels"] = labels
return batch
训练 DETR 模型。
到目前为止,所有的繁重工作都已完成。 现在,剩下要做的就是将拼图的每个部分逐个组装起来。 开始吧!
训练过程包括以下步骤
使用与预处理中相同的检查点,通过 AutoModelForObjectDetection 加载基础(预训练)模型。
在 TrainingArguments 中定义所有超参数和其他参数。
将训练参数传递到 HuggingFace Trainer 中,以及模型、数据集和图像。
调用
train()
方法并微调您的模型。
从您用于预处理的同一检查点加载模型时,请记住传递您先前从数据集的元数据创建的
label2id
和id2label
映射。 此外,我们指定ignore_mismatched_sizes=True
以将现有分类头替换为新的分类头。
from transformers import AutoModelForObjectDetection
id2label = {0: "head", 1: "helmet", 2: "person"}
label2id = {v: k for k, v in id2label.items()}
model = AutoModelForObjectDetection.from_pretrained(
checkpoint,
id2label=id2label,
label2id=label2id,
ignore_mismatched_sizes=True,
)
在继续之前,请登录 Hugging Face Hub 以在训练时动态上传您的模型。 这样,您无需处理检查点并将其保存在某处。
from huggingface_hub import notebook_login
notebook_login()
完成后,让我们开始训练模型。 我们首先定义训练参数并定义一个训练器对象,该对象使用这些参数进行训练,如下所示
from transformers import TrainingArguments
from transformers import Trainer
# Define the training arguments
training_args = TrainingArguments(
output_dir="detr-resnet-50-hardhat-finetuned",
per_device_train_batch_size=8,
num_train_epochs=3,
max_steps=1000,
fp16=True,
save_steps=10,
logging_steps=30,
learning_rate=1e-5,
weight_decay=1e-4,
save_total_limit=2,
remove_unused_columns=False,
push_to_hub=True,
)
# Define the trainer
trainer = Trainer(
model=model,
args=training_args,
data_collator=collate_fn,
train_dataset=train_dataset_transformed,
eval_dataset=test_dataset_transformed,
tokenizer=image_processor,
)
trainer.train()
训练完成后,您现在可以删除模型,因为检查点已上传到 HuggingFace Hub 中。
del model
torch.cuda.synchronize()
测试和推理
现在我们将尝试对我们新的微调模型进行推理。 在本教程中,我们将针对此图像进行测试
在这里,我们首先编写一个非常简单的代码,用于对一些新图像进行物体检测推理。 我们从对单个图像进行推理开始,然后我们将所有内容汇总在一起并从中创建一个函数。
import requests
from transformers import pipeline
# download a sample image
url = "https://huggingface.co/datasets/hf-vision/course-assets/resolve/main/test-helmet-object-detection.jpg"
image = Image.open(requests.get(url, stream=True).raw)
# make the object detection pipeline
obj_detector = pipeline(
"object-detection", model="anindya64/detr-resnet-50-dc5-hardhat-finetuned"
)
results = obj_detector(train_dataset[0]["image"])
print(results)
现在让我们创建一个非常简单的函数来在我们的图像上绘制结果。 我们从结果中获取分数、标签和相应的边界框坐标,我们将使用它们在图像中绘制。
def plot_results(image, results, threshold=0.7):
image = Image.fromarray(np.uint8(image))
draw = ImageDraw.Draw(image)
for result in results:
score = result["score"]
label = result["label"]
box = list(result["box"].values())
if score > threshold:
x, y, x2, y2 = tuple(box)
draw.rectangle((x, y, x2, y2), outline="red", width=1)
draw.text((x, y), label, fill="white")
draw.text(
(x + 0.5, y - 0.5),
text=str(score),
fill="green" if score > 0.7 else "red",
)
return image
最后,对我们使用的同一测试图像使用此函数。
results = obj_detector(image)
plot_results(image, results)
这将绘制以下输出
现在,让我们将所有内容汇总到一个简单的函数中。
def predict(image, pipeline, threshold=0.7):
results = pipeline(image)
return plot_results(image, results, threshold)
# Let's test for another test image
img = test_dataset[0]["image"]
predict(img, obj_detector)
让我们甚至在一个小的测试样本上使用我们的推理函数绘制多个图像。
from tqdm.auto import tqdm
def plot_images(dataset, indices):
"""
Plot images and their annotations.
"""
num_rows = len(indices) // 3
num_cols = 3
fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 10))
for i, idx in tqdm(enumerate(indices), total=len(indices)):
row = i // num_cols
col = i % num_cols
# Draw image
image = predict(dataset[idx]["image"], obj_detector)
# Display image on the corresponding subplot
axes[row, col].imshow(image)
axes[row, col].axis("off")
plt.tight_layout()
plt.show()
plot_images(test_dataset, range(6))
运行此函数将为我们提供如下所示的输出
嗯,这还不错。 如果我们进一步微调,我们可以改进结果。 您可以在此处找到此微调检查点。
< > 更新 在 GitHub 上