社区计算机视觉课程文档

MobileViT v2

Hugging Face's logo
加入 Hugging Face 社区

并获得增强型文档体验

开始使用

MobileViT v2

之前讨论的视觉Transformer架构在计算上很密集,很难在移动设备上运行。以前的移动视觉任务最先进的架构使用 CNN。但是,CNN 无法学习全局表示,因此它们的表现不如其Transformer 对应物。

MobileViT 架构旨在解决移动视觉任务所需的各种问题,例如低延迟和轻量级架构,同时提供Transformer 和 CNN 的优势。MobileViT 架构由苹果公司开发,并基于来自谷歌研究团队的 MobileNet。MobileViT 架构在之前的 MobileNet 架构基础上构建,通过添加 MobileViT 块和可分离自注意力。这两个特性允许快速实现闪电般快速的延迟,减少参数、计算复杂度,并在资源受限的设备上部署视觉 ML 模型。

MobileViT 架构

Sachin Mehta 和 Mohammad Rastegari 在论文 “MobileViT:轻量级、通用且移动友好的视觉Transformer” 中提出的 MobileViT 架构如下:MobileViT 架构

其中一些应该看起来类似于上一章。MobileNet 块、nxn 卷积、下采样、全局池化以及最后的线性层。

从全局池化层和线性层可以看出,这里显示的模型用于分类。但是,本文介绍的相同块可用于各种视觉应用。

MobileViT 块

MobileViT 块结合了 CNN 的局部处理和 Transformer 中的全局处理。它使用卷积和 Transformer 层的组合,使其能够捕获数据中的空间局部信息和全局依赖关系。

以下是 MobileViT 块的示意图:MobileViT 块

好吧,这有点难以理解。让我们分解一下。

  • 该块接受具有多个通道的图像。让我们假设对于 RGB 图像,有 3 个通道,因此该块接受一个三通道图像。
  • 然后,它对通道执行 N×N 卷积,并将它们附加到现有通道中。
  • 然后,该块创建这些通道的线性组合,并将它们添加到现有通道堆栈中。
  • 对于每个通道,这些图像被展开成扁平化的补丁。
  • 然后,这些扁平化的补丁通过 Transformer 传递,以将它们投影到新的补丁中。
  • 然后将这些补丁折叠回一起,以创建具有 d 维的图像。
  • 之后,将逐点卷积叠加在拼接的图像上。
  • 然后,将拼接的图像与原始 RGB 图像重新组合。

这种方法允许在保留补丁位置信息的同时,通过保留补丁位置信息来建模非局部依赖关系和局部依赖关系,从而实现 H×W(整个输入大小)的感受野。这可以在补丁的展开和折叠中看到。

感受野是输入空间中影响特定层特征的区域的大小。

这种复合方法使 MobileViT 能够拥有比传统 CNN 更少的参数,甚至具有更好的精度!MobileViT CNNPreformance

原始 MobileViT 架构中的主要效率瓶颈是 Transformer 中的多头自注意力,它需要关于输入令牌的 O(k^2) 时间复杂度。

多头自注意力还需要进行代价高昂的操作,如批次矩阵乘法,这会影响资源受限设备的延迟。

这些作者在另一篇论文中详细介绍了如何使注意力更快地运行。他们将其称为可分离自注意力。

可分离自注意力

在传统的多头注意力中,关于输入令牌的 Big O 是二次的(O(k^2))。本文中介绍的可分离自注意力关于输入令牌的复杂度为 O(k)。

此外,该注意力方法不使用任何批次矩阵乘法,这有助于减少移动电话等资源受限设备的延迟。

这是一个巨大的进步!

已经出现了许多其他形式的注意力,使得复杂度范围从 O(k)、O(k*sqrt(k)) 到 O(k*log(k))。

可分离自注意力不是第一个复杂度为 O(k) 的论文。在可分离自注意力之前,Linformer 中也实现了关于注意力的 O(k) 复杂度。

但是,它仍然使用代价高昂的操作,如批次矩阵乘法。

下面展示了 Transformer、Linformer 和 MobileViT 中注意力机制的比较:注意力比较

上图比较了 Transformer、Linformer 和 MobileViT v2 架构中每种单独的注意力类型。

例如,在 Transformer 和 Linformer 架构中,注意力计算执行两次批次级矩阵乘法。

但是,在可分离自注意力的情况下,这两个批次级乘法被两个单独的线性计算所取代。这使得推断速度进一步提高。

结论

MobileViT 块保留空间局部信息,同时发展全局表示,结合了 Transformer 和 CNN 的优势。它们提供了一个涵盖整个图像的感受野。

将可分离自注意力引入到现有的架构中,进一步提高了准确性和推断速度。推断测试

如上所示,在 iPhone 12s 上使用不同架构进行的测试表明,引入可分离注意力后,性能大幅提升!

总的来说,MobileViT 架构对于资源受限的视觉任务来说是一个非常强大的架构,它提供了快速推断和高精度。

Transformers 库

如果您想在本地试用 MobileViTv2,可以使用 HuggingFace 的 transformers 库,方法如下

pip install transformers

以下是如何使用 MobileViT 模型对图像进行分类的简短代码片段。

from transformers import AutoImageProcessor, MobileViTV2ForImageClassification
from datasets import load_dataset
from PIL import Image

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained(
    "apple/mobilevitv2-1.0-imagenet1k-256"
)
model = MobileViTV2ForImageClassification.from_pretrained(
    "apple/mobilevitv2-1.0-imagenet1k-256"
)

inputs = image_processor(image, return_tensors="pt")

logits = model(**inputs).logits

# model predicts one of the 1000 ImageNet classes
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])

推断 API

为了获得更轻量级的计算机视觉设置,您可以将 Hugging Face 推断 API 与 MobileViTv2 一起使用。推断 API 是一个用于与 Hugging Face Hub 上的许多模型交互的 API。我们可以通过 Python 以以下方式查询推断 API。

import json
import requests

headers = {"Authorization": f"Bearer {API_TOKEN}"}
API_URL = (
    "https://api-inference.huggingface.co/models/apple/mobilevitv2-1.0-imagenet1k-256"
)


def query(filename):
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.request("POST", API_URL, headers=headers, data=data)
    return json.loads(response.content.decode("utf-8"))


data = query("cats.jpg")

我们可以使用 javascript 以以下方式执行相同的操作。

import fetch from "node-fetch";
import fs from "fs";
async function query(filename) {
    const data = fs.readFileSync(filename);
    const response = await fetch(
        "https://api-inference.huggingface.co/models/apple/mobilevitv2-1.0-imagenet1k-256",
        {
            headers: { Authorization: `Bearer ${API_TOKEN}` },
            method: "POST",
            body: data,
        }
    );
    const result = await response.json();
    return result;
}
query("cats.jpg").then((response) => {
    console.log(JSON.stringify(response));
});

最后,我们可以通过 curl 查询推断 API。

curl https://api-inference.huggingface.co/models/apple/mobilevitv2-1.0-imagenet1k-256 \
        -X POST \
        --data-binary '@cats.jpg' \
        -H "Authorization: Bearer ${HF_API_TOKEN}"
< > 更新 在 GitHub 上