社区计算机视觉课程文档

MobileViT v2

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始

MobileViT v2

先前讨论的 Vision Transformer 架构计算密集,难以在移动设备上运行。先前的最先进架构使用 CNN 进行移动视觉任务。然而,CNN 无法学习全局表示,因此它们的性能比 Transformer 同类产品差。

MobileViT 架构旨在解决视觉移动任务所需的低延迟和轻量级架构等问题,同时提供 Transformer 和 CNN 的优势。MobileViT 架构由 Apple 开发,并基于 Google 研究团队的 MobileNet 构建。MobileViT 架构通过添加 MobileViT 块和可分离自注意力,在之前的 MobileNet 架构基础上构建。这两个特性实现了闪电般的低延迟、参数减少、计算复杂性降低以及在资源受限设备上部署视觉 ML 模型。

MobileViT 架构

Sachin Mehta 和 Mohammad Rastegari 在论文“MobileViT:轻量级、通用且移动友好的 Vision Transformer”中提出的 MobileViT 架构如下: MobileViT 架构

其中一些内容应该看起来与前一章类似。MobileNet 块、nxn 卷积、下采样、全局池化和最终线性层。

如全局池化层和线性层所示,此处展示的模型用于分类。但是,本文介绍的相同模块可用于各种视觉应用。

MobileViT 块

MobileViT 块结合了 CNN 的局部处理和全局处理,如 Transformer 中所见。它结合了卷积和 Transformer 层,使其能够捕获数据中的空间局部信息和全局依赖关系。

MobileViT 块的图示如下: MobileViT 块

好的,信息量有点大。让我们分解一下。

  • 该块接收具有多个通道的图像。假设对于 RGB 图像有 3 个通道,因此该块接收一个三通道图像。
  • 然后,它对通道执行 N x N 卷积,并将它们附加到现有通道。
  • 然后,该块创建这些通道的线性组合,并将它们添加到现有的通道堆栈中。
  • 对于每个通道,这些图像被展开成扁平的图块。
  • 然后,这些扁平的图块通过 Transformer 进行投影,形成新的图块。
  • 然后,这些图块被折叠回一起,以创建具有 d 维度的图像。
  • 之后,在拼接图像上覆盖一个逐点卷积。
  • 然后,拼接图像与原始 RGB 图像重新组合。

这种方法允许 H x W(整个输入大小)的感受野,同时通过保留图块位置信息来建模非局部依赖关系和局部依赖关系。这可以在图块的展开和折叠中看到。

感受野是输入空间中影响特定层特征的区域大小。

这种复合方法使 MobileViT 比传统 CNN 具有更少的参数,甚至更好的精度! MobileViT CNN 性能

原始 MobileViT 架构中的主要效率瓶颈是 Transformer 中的多头自注意力,它对输入令牌的时间复杂度为 O(k^2)。

多头自注意力还需要昂贵的操作,如批量矩阵乘法,这会影响资源受限设备上的延迟。

这些相同的作者撰写了另一篇论文,专门介绍如何使注意力机制运行得更快。他们称之为可分离自注意力。

可分离自注意力

在传统的多头注意力中,关于输入令牌的大 O 是二次方 (O(k^2))。本文介绍的可分离自注意力对输入令牌的复杂度为 O(k)。

此外,注意力方法不使用任何批量矩阵乘法,这有助于减少资源受限设备(如手机)上的延迟。

这是一个巨大的改进!

还有许多其他形式的注意力机制,其复杂度范围从 O(k)、O(k*sqrt(k))、O(k*log(k))。

可分离自注意力并不是第一个具有 O(k) 复杂度的论文。在 Linformer 中,可分离自注意力之前也实现了注意力机制的 O(k) 复杂度。

然而,它仍然使用了昂贵的操作,如批量矩阵乘法。

Transformer、Linformer 和 MobileViT 中注意力机制的比较如下所示: 注意力机制比较

上图给出了 Transformer、Linformer 和 MobileViT v2 架构之间每种个体类型的注意力机制的比较。

例如,在 Transformer 和 Linformer 架构中,注意力计算都执行两次批量矩阵乘法。

然而,在可分离自注意力的情况下,这两个批量乘法被两个单独的线性计算所取代。这允许进一步提高推理速度。

结论

MobileViT 块在开发全局表示的同时保留空间局部信息,结合了 Transformer 和 CNN 的优势。它们提供了一个包含整个图像的感受野。

将可分离自注意力引入到现有架构中,甚至进一步提高了准确性和推理速度。 推理测试

在 iPhone 12s 上使用不同架构进行的测试表明,引入可分离注意力后,性能大幅提升,如上所示!

总的来说,MobileViT 架构是一种非常强大的架构,适用于资源受限的视觉任务,可提供快速的推理和高精度。

Transformers 库

如果您想在本地试用 MobileViTv2,可以使用 HuggingFace 的 transformers 库,方法如下

pip install transformers

以下是关于如何使用 MobileViT 模型对图像进行分类的简短代码片段。

from transformers import AutoImageProcessor, MobileViTV2ForImageClassification
from datasets import load_dataset
from PIL import Image

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained(
    "apple/mobilevitv2-1.0-imagenet1k-256"
)
model = MobileViTV2ForImageClassification.from_pretrained(
    "apple/mobilevitv2-1.0-imagenet1k-256"
)

inputs = image_processor(image, return_tensors="pt")

logits = model(**inputs).logits

# model predicts one of the 1000 ImageNet classes
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])

推理 API

对于更轻量级的计算机视觉设置,您可以使用 Hugging Face 推理 API 和 MobileViTv2。推理 API 是一个与 Hugging Face Hub 上提供的许多模型交互的 API。我们可以通过 Python 像下面这样查询推理 API。

import json
import requests

headers = {"Authorization": f"Bearer {API_TOKEN}"}
API_URL = (
    "https://api-inference.huggingface.co/models/apple/mobilevitv2-1.0-imagenet1k-256"
)


def query(filename):
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.request("POST", API_URL, headers=headers, data=data)
    return json.loads(response.content.decode("utf-8"))


data = query("cats.jpg")

我们可以用 javascript 做同样的事情,如下所示。

import fetch from "node-fetch";
import fs from "fs";
async function query(filename) {
    const data = fs.readFileSync(filename);
    const response = await fetch(
        "https://api-inference.huggingface.co/models/apple/mobilevitv2-1.0-imagenet1k-256",
        {
            headers: { Authorization: `Bearer ${API_TOKEN}` },
            method: "POST",
            body: data,
        }
    );
    const result = await response.json();
    return result;
}
query("cats.jpg").then((response) => {
    console.log(JSON.stringify(response));
});

最后,我们可以通过 curl 查询推理 API。

curl https://api-inference.huggingface.co/models/apple/mobilevitv2-1.0-imagenet1k-256 \
        -X POST \
        --data-binary '@cats.jpg' \
        -H "Authorization: Bearer ${HF_API_TOKEN}"
< > 更新 在 GitHub 上