检测Transformer (DETR)
架构概述
DETR 主要用于目标检测任务,即检测图像中的物体。例如,模型的输入可以是道路的图像,模型的输出可以是 [('汽车',X1,Y1,W1,H1),('行人',X2,Y2,W2,H2)]
,其中 X、Y、W、H 代表 x、y 坐标,表示边界框的位置,以及框的宽度和高度。 像 YOLO 这样的传统目标检测模型包含手工设计的特征,例如锚框先验,需要对物体位置和形状进行初始猜测,这会影响后续训练。然后使用后处理步骤来去除重叠的边界框,这需要仔细选择其过滤启发式方法。检测Transformer,简称 DETR,通过在特征提取主干之后使用编码器-解码器Transformer 来直接并行预测边界框,从而简化了检测器,只需要最小的后处理。
DETR 的模型架构以 CNN 主干开始,类似于其他基于图像的网络,其输出经过处理后被馈送到 Transformer 编码器中,生成 N 个嵌入。编码器嵌入被添加到学习到的位置嵌入(称为对象查询)中,并用于 Transformer 解码器,生成另外 N 个嵌入。作为最后一步,每个 N 个嵌入都经过单独的前馈层,以预测边界框的宽度、高度、坐标以及物体类别(或是否存在物体)。
主要特征
编码器-解码器
与其他Transformer 一样,Transformer 编码器期望 CNN 主干的输出是序列。因此,大小为 [维度,高度,宽度]
的特征图被缩小,然后展平为 [维度,小于高度 x 宽度]
。 左:可视化了特征图中的 256 个维度中的 12 个。每个维度提取原始猫图像的某些特征,同时缩小原始图像。一些维度对猫上的图案更加关注;一些维度对床单更加关注。 右:保留大小为 256 的原始特征维度,宽度和高度被进一步缩小并展平为大小 850。
由于Transformer 是置换不变的,因此在编码器和解码器中都添加了位置嵌入,以提醒模型嵌入来自图像上的哪个位置。在编码器中,使用固定位置编码,而在解码器中,使用学习到的位置编码(对象查询)。固定编码类似于原始Transformer 论文中使用的编码,其中编码由不同频率的正弦函数在不同特征维度上定义。它提供了位置感,而无需任何学习参数,由图像上的位置索引。学习到的编码也由位置索引,但每个位置都有一个单独的编码,在整个训练过程中学习,以模型理解的方式表示位置。
基于集合的全局损失函数
在 YOLO(一种流行的目标检测模型)中,损失函数包含边界框、物体性(即物体在感兴趣区域中存在的概率)和类别损失。损失是在每个网格单元的多个边界框上计算的,这些边界框的数量是固定的。另一方面,在 DETR 中,架构预计会以置换不变的方式生成唯一的边界框(即,检测的顺序在输出中无关紧要,并且边界框必须有所不同,不能都相同)。因此,需要匹配来评估预测的准确性。
二分匹配
二分匹配是一种在真实边界框和预测框之间计算一对一匹配的方法。它找到真实框和预测框之间相似度最高的匹配,以及类别。这确保最接近的预测与相应的真实框相匹配,以便在损失函数中正确调整框和类别。如果没有进行匹配,则即使预测是正确的,与真实框顺序不一致的预测也会被标记为错误。
使用 DETR 进行目标检测
要查看如何使用 Hugging Face transformers 对 DETR 进行推理的示例,请参见 DETR.ipynb
。
DETR 的演变
可变形 DETR
DETR 的两个主要问题是收敛过程缓慢且速度慢以及小目标检测效果不佳。可变形注意力
第一个问题通过使用可变形注意力解决,可变形注意力减少了需要关注的采样点数量。由于全局注意力,传统注意力效率低下,并且严重限制了图像的解析度。该模型仅关注每个参考点周围的固定数量的采样点,并且参考点是模型根据输入学习的。例如,在一张狗的图像中,参考点可能在狗的中心,采样点靠近耳朵、嘴巴、尾巴等。
多尺度可变形注意力模块
第二个问题与 YOLOv3 的解决方法类似,其中引入了多尺度特征图。在卷积神经网络中,较早的层提取较小的细节(例如线条),而较后的层提取较大的细节(例如车轮、耳朵)。类似地,可变形注意力的不同层会导致不同的解析度级别。通过将编码器中一些层的输出连接到解码器,它允许模型检测各种尺寸的目标。
条件 DETR
条件 DETR 还旨在解决原始 DETR 中训练收敛速度慢的问题,从而使收敛速度提高了 6.7 倍以上。作者发现,目标查询是通用的,而不是特定于输入图像的。在解码器中使用条件交叉注意力,查询可以更好地定位用于边界框回归的区域。 左:DETR 解码器层。右:可变形 DETR 解码器层
上图比较了原始 DETR 和可变形 DETR 解码器层,主要区别在于交叉注意力块的查询输入。作者区分了内容查询 cq(解码器自注意力输出)和空间查询 pq。原始 DETR 只将它们加在一起。在可变形 DETR 中,它们被连接在一起,其中 cq 关注目标的内容,而 pq 关注边界框区域。
空间查询 pq 是解码器嵌入和目标查询投影到同一空间(分别变为 T 和 ps)并将它们相乘的结果。以前的层的解码器嵌入包含边界框区域的信息,而目标查询包含每个边界框的学习参考点的信息。因此,它们的投影结合成一个表示,允许交叉注意力测量它们与编码器输入和正弦位置嵌入的相似性。这比仅使用目标查询和固定参考点的 DETR 更有效。
DETR 推理
您可以使用 Hugging Face Hub 上现有的 DETR 模型进行推理,如下所示
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image
import requests
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
# initialize the model
processor = DetrImageProcessor.from_pretrained(
"facebook/detr-resnet-101", revision="no_timm"
)
model = DetrForObjectDetection.from_pretrained(
"facebook/detr-resnet-101", revision="no_timm"
)
# preprocess the inputs and infer
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
# convert outputs (bounding boxes and class logits) to COCO API
# non max supression above 0.9
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(
outputs, target_sizes=target_sizes, threshold=0.9
)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
box = [round(i, 2) for i in box.tolist()]
print(
f"Detected {model.config.id2label[label.item()]} with confidence "
f"{round(score.item(), 3)} at location {box}"
)
输出如下。
Detected cat with confidence 0.998 at location [344.06, 24.85, 640.34, 373.74]
Detected remote with confidence 0.997 at location [328.13, 75.93, 372.81, 187.66]
Detected remote with confidence 0.997 at location [39.34, 70.13, 175.56, 118.78]
Detected cat with confidence 0.998 at location [15.36, 51.75, 316.89, 471.16]
Detected couch with confidence 0.995 at location [-0.19, 0.71, 639.73, 474.17]
DETR 的 PyTorch 实现
下面显示了原始论文中 DETR 的实现
import torch
from torch import nn
from torchvision.models import resnet50
class DETR(nn.Module):
def __init__(
self, num_classes, hidden_dim, nheads, num_encoder_layers, num_decoder_layers
):
super().__init__()
self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
self.conv = nn.Conv2d(2048, hidden_dim, 1)
self.transformer = nn.Transformer(
hidden_dim, nheads, num_encoder_layers, num_decoder_layers
)
self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
self.linear_bbox = nn.Linear(hidden_dim, 4)
self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
def forward(self, inputs):
x = self.backbone(inputs)
h = self.conv(x)
H, W = h.shape[-2:]
pos = (
torch.cat(
[
self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
],
dim=-1,
)
.flatten(0, 1)
.unsqueeze(1)
)
h = self.transformer(
pos + h.flatten(2).permute(2, 0, 1), self.query_pos.unsqueeze(1)
)
return self.linear_class(h), self.linear_bbox(h).sigmoid()
逐行遍历前向函数:
骨干网络
输入图像首先通过 ResNet 骨干网络,然后通过一个卷积层,该层将维度缩减到 hidden_dim
x = self.backbone(inputs) h = self.conv(x)
它们在 __init__
函数中声明
self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
self.conv = nn.Conv2d(2048, hidden_dim, 1)
位置嵌入
虽然在论文中分别在编码器和解码器中使用了固定嵌入和训练嵌入,但为了简单起见,作者在实现中对两者都使用了训练嵌入。
pos = (
torch.cat(
[
self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
],
dim=-1,
)
.flatten(0, 1)
.unsqueeze(1)
)
它们在此处声明为 nn.Parameter
。行嵌入和列嵌入组合表示图像中的位置。
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
调整大小
在进入 Transformer 之前,大小为 (批次大小,隐藏维度,H,W)
的特征被重新整形为 (隐藏维度,批次大小,H*W)
。这使它们成为 Transformer 的顺序输入
h.flatten(2).permute(2, 0, 1)
Transformernn.Transformer
函数将第一个参数作为编码器的输入,并将第二个参数作为编码器的输入。如您所见,编码器接受重新调整大小的特征加上位置嵌入,而解码器接受 query_pos
,即解码器位置嵌入。
h = self.transformer(pos + h.flatten(2).permute(2, 0, 1), self.query_pos.unsqueeze(1))
前馈网络
最后,输出(大小为 (query_pos_dim,批次大小,隐藏维度)
的张量)通过两个线性层。
return self.linear_class(h), self.linear_bbox(h).sigmoid()
其中第一个预测类别。为 无目标
类添加了另一个类别
self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
第二个线性层预测边界框,输出大小为 4,用于 xy 坐标、高度和宽度。
self.linear_bbox = nn.Linear(hidden_dim, 4)