社区计算机视觉课程文档

图像分类的迁移学习和微调 Vision Transformer

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

图像分类的迁移学习和微调 Vision Transformer

简介

随着 Transformer 架构在自然语言处理领域的良好扩展,相同的架构也被应用于图像,通过创建图像的小块并将它们视为 tokens。结果就是 Vision Transformer (ViT)。在我们开始迁移学习/微调概念之前,让我们比较一下卷积神经网络 (CNN) 和 Vision Transformer。

Vision Transformer (VT) 概述

总而言之,在 Vision Transformer 中,图像被重组为 2D 网格的 patches。模型在这些 patches 上进行训练。

主要思想可以在下图找到: Vision Transformer

但是有一个问题!卷积神经网络 (CNN) 的设计带有一个 Vision Transformer 中缺失的假设。这个假设是基于我们人类如何感知图像中的对象。这将在以下部分中描述。

CNN 和 Vision Transformer 之间有什么区别?

归纳偏置

归纳偏置是机器学习中用于描述学习算法用于进行预测的一组假设的术语。简单来说,归纳偏置就像一个捷径,可以帮助机器学习模型根据它目前看到的信息做出有根据的猜测。

以下是我们在 CNN 中观察到的一些归纳偏置

  • 平移等变性:对象可以出现在图像中的任何位置,CNN 可以检测到它的特征。
  • 局部性:图像中的像素主要与其周围的像素交互以形成特征。

CNN 模型非常擅长这两种偏置。ViT 没有这种假设。这就是为什么对于数据集大小达到一定阈值之前,实际上 CNN 比 ViT 更好。但是 ViT 有另一种力量!Transformer 架构(主要是)不同类型的线性函数允许 ViT 变得高度可扩展。反过来,这使得 ViT 能够通过大量数据克服没有上述两种归纳偏置的问题!

但是,每个人如何才能访问海量数据集?

对于每个人来说,在数百万张图像上训练 Vision Transformer 以获得良好的性能是不可行的。相反,可以使用来自 Hugging Face Hub 等地方的公开模型权重。

如何处理预训练模型?您可以应用迁移学习并对其进行微调!

图像分类的迁移学习和微调

迁移学习的思想是,我们可以利用在非常大的数据集上训练的 Vision Transformer 学习到的特征,并将这些特征应用于我们的数据集。这可以显着提高模型性能,尤其是在我们的数据集可用于训练的数据有限时。

由于我们正在利用学习到的特征,因此我们也不需要更新整个模型。通过冻结大部分权重,我们可以仅训练某些层,以在更少的训练时间和低 GPU 消耗下获得出色的性能。

多类别图像分类

您可以通过本 notebook 中使用 Vision Transformer 进行图像分类的迁移学习教程

Open In Colab

这就是我们将要构建的内容:一个图像分类器,用于区分狗和猫的品种


您的数据集的领域可能与预训练模型的数据集非常不同。然而,与其从头开始训练 Vision Transformer,我们可以选择更新整个预训练模型的权重,尽管学习率较低,这将“微调”模型以使其在我们的数据上表现良好。

但是,在大多数情况下,在 Vision Transformer 的情况下,应用迁移学习就足够了。

多标签图像分类

上面的教程教授了多类别图像分类,其中每张图像仅分配有 1 个类别。如果每个图像在多类别数据集中都有多个标签,该怎么办?

本 notebook 将引导您完成使用 Vision Transformer 进行多标签图像分类的微调教程

Open In Colab

我们还将学习如何使用 Hugging Face Accelerate 来编写我们的自定义训练循环。这就是您可以期望看到的作为多标签分类教程的结果


其他资源

  • 原始 Vision Transformer 论文:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale 论文
  • Swin Transformer 论文:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 论文
  • 为了更好地理解 Vision Transformer 的训练数据量、正则化、增强、模型大小和计算预算之间相互作用的系统实证研究:How to train your Vision Transformers? Data, Augmentation, and Regularization in Vision Transformers 论文
< > 在 GitHub 上更新