社区计算机视觉课程文档

扩张邻域注意力 Transformer (DINAT)

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

扩张邻域注意力 Transformer (DINAT)

DINAT Architecture Diagram

架构概述

扩张邻域注意力 Transformer (DiNAT) 是一种创新的分层视觉 transformer,旨在提升深度学习模型的性能,尤其是在视觉识别任务中。与传统的 transformers 不同,后者采用自注意力机制,随着模型规模的扩大,计算成本可能会变得很高,DiNAT 引入了扩张邻域注意力 (DiNA)。DiNA 通过结合稀疏的全局注意力,扩展了名为邻域注意力 (NA) 的局部注意力机制,而没有增加额外的计算开销。这种扩展使 DiNA 能够捕获更全局的上下文,指数级地扩展感受野,并有效地建模更长距离的相互依赖关系。

DiNAT 在其架构中结合了 NA 和 DiNA,从而产生了一种 transformer 模型,该模型能够保持局部性、保持平移等变性,并在下游视觉任务中实现显著的性能提升。使用 DiNAT 进行的实验证明了其在各种视觉识别任务中优于强大的基线模型,如 NAT、Swin 和 ConvNeXt。

DiNAT 的核心:邻域注意力

DiNAT 基于邻域注意力 (NA) 架构,这是一种专门为计算机视觉任务设计的注意力机制,旨在有效地捕获图像中像素之间的关系。简单来说,想象你有一张图像,图像中的每个像素都需要理解并关注其附近的像素,以便理解整张图片。让我们来看看 NA 的关键特征

  • 局部关系:NA 捕获局部关系,允许每个像素考虑来自其直接周围环境的信息。这类似于我们如何通过首先查看离我们最近的物体,然后再考虑整个视图来理解场景。

  • 感受野:NA 允许像素在不需要太多额外计算的情况下,扩展他们对其周围环境的理解。它动态地扩展他们的范围或“注意力跨度”,以便在必要时包括更远的邻居。

本质上,邻域注意力是一种使图像中的像素能够专注于其周围环境的技术,帮助他们有效地理解局部关系。这种局部理解有助于建立对整个图像的详细理解,同时有效地管理计算资源。

DiNAT Architecture Diagram

DiNAT 的演进

扩张邻域注意力 Transformer 的发展代表了视觉 transformers 的一个重大改进。它解决了现有注意力机制的局限性。最初,引入邻域注意力是为了提供局部性和效率,但它在捕获全局上下文方面有所不足。为了克服这一局限性,引入了扩张邻域注意力 (DiNA) 的概念。DiNA 通过将邻域扩展到更大的稀疏区域来扩展 NA。这允许捕获更多的全局上下文,并指数级地增加感受野,而不会增加计算负担。接下来的发展是 DiNAT,它将局部化的 NA 与 DiNA 的扩展全局上下文相结合。DiNAT 通过在整个模型中逐步改变扩张率、优化感受野和简化特征学习来实现这一点。

使用 DiNAT 进行图像分类

您可以使用 🤗 transformers 的 shi-labs/dinat-mini-in1k-224 模型对 ImageNet-1k 图像进行分类。您也可以针对自己的用例对其进行微调。

from transformers import AutoImageProcessor, DinatForImageClassification
from PIL import Image
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = AutoImageProcessor.from_pretrained("shi-labs/dinat-mini-in1k-224")
model = DinatForImageClassification.from_pretrained("shi-labs/dinat-mini-in1k-224")

inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
# model predicts one of the 1000 ImageNet classes
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

参考文献

< > GitHub 上更新