MobileViT v2
之前讨论的视觉Transformer架构在计算上很密集,很难在移动设备上运行。以前的移动视觉任务最先进的架构使用 CNN。但是,CNN 无法学习全局表示,因此它们的表现不如其Transformer 对应物。
MobileViT 架构旨在解决移动视觉任务所需的各种问题,例如低延迟和轻量级架构,同时提供Transformer 和 CNN 的优势。MobileViT 架构由苹果公司开发,并基于来自谷歌研究团队的 MobileNet。MobileViT 架构在之前的 MobileNet 架构基础上构建,通过添加 MobileViT 块和可分离自注意力。这两个特性允许快速实现闪电般快速的延迟,减少参数、计算复杂度,并在资源受限的设备上部署视觉 ML 模型。
MobileViT 架构
Sachin Mehta 和 Mohammad Rastegari 在论文 “MobileViT:轻量级、通用且移动友好的视觉Transformer” 中提出的 MobileViT 架构如下:
其中一些应该看起来类似于上一章。MobileNet 块、nxn 卷积、下采样、全局池化以及最后的线性层。
从全局池化层和线性层可以看出,这里显示的模型用于分类。但是,本文介绍的相同块可用于各种视觉应用。
MobileViT 块
MobileViT 块结合了 CNN 的局部处理和 Transformer 中的全局处理。它使用卷积和 Transformer 层的组合,使其能够捕获数据中的空间局部信息和全局依赖关系。
以下是 MobileViT 块的示意图:
好吧,这有点难以理解。让我们分解一下。
- 该块接受具有多个通道的图像。让我们假设对于 RGB 图像,有 3 个通道,因此该块接受一个三通道图像。
- 然后,它对通道执行 N×N 卷积,并将它们附加到现有通道中。
- 然后,该块创建这些通道的线性组合,并将它们添加到现有通道堆栈中。
- 对于每个通道,这些图像被展开成扁平化的补丁。
- 然后,这些扁平化的补丁通过 Transformer 传递,以将它们投影到新的补丁中。
- 然后将这些补丁折叠回一起,以创建具有 d 维的图像。
- 之后,将逐点卷积叠加在拼接的图像上。
- 然后,将拼接的图像与原始 RGB 图像重新组合。
这种方法允许在保留补丁位置信息的同时,通过保留补丁位置信息来建模非局部依赖关系和局部依赖关系,从而实现 H×W(整个输入大小)的感受野。这可以在补丁的展开和折叠中看到。
这种复合方法使 MobileViT 能够拥有比传统 CNN 更少的参数,甚至具有更好的精度!
原始 MobileViT 架构中的主要效率瓶颈是 Transformer 中的多头自注意力,它需要关于输入令牌的 O(k^2) 时间复杂度。
多头自注意力还需要进行代价高昂的操作,如批次矩阵乘法,这会影响资源受限设备的延迟。
这些作者在另一篇论文中详细介绍了如何使注意力更快地运行。他们将其称为可分离自注意力。
可分离自注意力
在传统的多头注意力中,关于输入令牌的 Big O 是二次的(O(k^2))。本文中介绍的可分离自注意力关于输入令牌的复杂度为 O(k)。
此外,该注意力方法不使用任何批次矩阵乘法,这有助于减少移动电话等资源受限设备的延迟。
这是一个巨大的进步!
可分离自注意力不是第一个复杂度为 O(k) 的论文。在可分离自注意力之前,Linformer 中也实现了关于注意力的 O(k) 复杂度。
但是,它仍然使用代价高昂的操作,如批次矩阵乘法。
下面展示了 Transformer、Linformer 和 MobileViT 中注意力机制的比较:
上图比较了 Transformer、Linformer 和 MobileViT v2 架构中每种单独的注意力类型。
例如,在 Transformer 和 Linformer 架构中,注意力计算执行两次批次级矩阵乘法。
但是,在可分离自注意力的情况下,这两个批次级乘法被两个单独的线性计算所取代。这使得推断速度进一步提高。
结论
MobileViT 块保留空间局部信息,同时发展全局表示,结合了 Transformer 和 CNN 的优势。它们提供了一个涵盖整个图像的感受野。
将可分离自注意力引入到现有的架构中,进一步提高了准确性和推断速度。
如上所示,在 iPhone 12s 上使用不同架构进行的测试表明,引入可分离注意力后,性能大幅提升!
总的来说,MobileViT 架构对于资源受限的视觉任务来说是一个非常强大的架构,它提供了快速推断和高精度。
Transformers 库
如果您想在本地试用 MobileViTv2,可以使用 HuggingFace 的 transformers
库,方法如下
pip install transformers
以下是如何使用 MobileViT 模型对图像进行分类的简短代码片段。
from transformers import AutoImageProcessor, MobileViTV2ForImageClassification
from datasets import load_dataset
from PIL import Image
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
image_processor = AutoImageProcessor.from_pretrained(
"apple/mobilevitv2-1.0-imagenet1k-256"
)
model = MobileViTV2ForImageClassification.from_pretrained(
"apple/mobilevitv2-1.0-imagenet1k-256"
)
inputs = image_processor(image, return_tensors="pt")
logits = model(**inputs).logits
# model predicts one of the 1000 ImageNet classes
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])
推断 API
为了获得更轻量级的计算机视觉设置,您可以将 Hugging Face 推断 API 与 MobileViTv2 一起使用。推断 API 是一个用于与 Hugging Face Hub 上的许多模型交互的 API。我们可以通过 Python 以以下方式查询推断 API。
import json
import requests
headers = {"Authorization": f"Bearer {API_TOKEN}"}
API_URL = (
"https://api-inference.huggingface.co/models/apple/mobilevitv2-1.0-imagenet1k-256"
)
def query(filename):
with open(filename, "rb") as f:
data = f.read()
response = requests.request("POST", API_URL, headers=headers, data=data)
return json.loads(response.content.decode("utf-8"))
data = query("cats.jpg")
我们可以使用 javascript 以以下方式执行相同的操作。
import fetch from "node-fetch";
import fs from "fs";
async function query(filename) {
const data = fs.readFileSync(filename);
const response = await fetch(
"https://api-inference.huggingface.co/models/apple/mobilevitv2-1.0-imagenet1k-256",
{
headers: { Authorization: `Bearer ${API_TOKEN}` },
method: "POST",
body: data,
}
);
const result = await response.json();
return result;
}
query("cats.jpg").then((response) => {
console.log(JSON.stringify(response));
});
最后,我们可以通过 curl 查询推断 API。
curl https://api-inference.huggingface.co/models/apple/mobilevitv2-1.0-imagenet1k-256 \
-X POST \
--data-binary '@cats.jpg' \
-H "Authorization: Bearer ${HF_API_TOKEN}"