使用 🤗 Transformers 微调 ViT 进行图像分类

发布于 2022 年 2 月 11 日
在 GitHub 上更新
Open In Colab

正如基于 transformer 的模型彻底改变了自然语言处理一样,我们现在也看到大量论文将它们应用于各种其他领域。其中最具革命性之一的是 Vision Transformer (ViT),它于 2021 年 6 月由 Google Brain 的研究团队推出。

这篇论文探讨了如何像标记句子一样标记图像,以便将它们传递给 transformer 模型进行训练。这真的非常简单……

  1. 将图像分割成子图像块网格
  2. 使用线性投影嵌入每个图像块
  3. 每个嵌入的图像块都成为一个标记,由此产生的嵌入图像块序列就是您传递给模型的序列。

事实证明,完成上述操作后,您可以像在自然语言处理任务中一样预训练和微调 transformer。非常棒 😎。


在这篇博文中,我们将逐步介绍如何利用 🤗 datasets 下载和处理图像分类数据集,然后使用它们通过 🤗 transformers 微调预训练的 ViT。

首先,让我们先安装这两个软件包。

pip install datasets transformers

加载数据集

让我们从加载一个小型图像分类数据集并查看其结构开始。

我们将使用 beans 数据集,它是健康和不健康豆叶图片的集合。🍃

from datasets import load_dataset

ds = load_dataset('beans')
ds

让我们看看 beans 数据集 'train' 分割中的第 400 个示例。您会注意到数据集中的每个示例都有 3 个特征

  1. image:PIL 图像
  2. image_file_path:作为 image 加载的图像文件的 str 路径
  3. labels:一个 datasets.ClassLabel 特征,它是标签的整数表示。(稍后您会看到如何获取字符串类名,别担心!)
ex = ds['train'][400]
ex
{
  'image': <PIL.JpegImagePlugin ...>,
  'image_file_path': '/root/.cache/.../bean_rust_train.4.jpg',
  'labels': 1
}

让我们看看图像 👀

image = ex['image']
image

这绝对是一片叶子!但是是什么种类呢?😅

由于此数据集的 'labels' 特征是 datasets.features.ClassLabel,我们可以使用它来查找此示例标签 ID 对应的名称。

首先,让我们访问 'labels' 的特征定义。

labels = ds['train'].features['labels']
labels
ClassLabel(num_classes=3, names=['angular_leaf_spot', 'bean_rust', 'healthy'], names_file=None, id=None)

现在,让我们打印出我们示例的类标签。您可以使用 ClassLabelint2str 函数来完成此操作,顾名思义,它允许传递类的整数表示来查找字符串标签。

labels.int2str(ex['labels'])
'bean_rust'

原来上面显示的叶子感染了豆锈病,这是一种严重的豆类植物病害。😢

让我们编写一个函数,它将显示每个类别的一些示例网格,以便更好地了解您正在处理的内容。

import random
from PIL import ImageDraw, ImageFont, Image

def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):

    w, h = size
    labels = ds['train'].features['labels'].names
    grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
    draw = ImageDraw.Draw(grid)
    font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24)

    for label_id, label in enumerate(labels):

        # Filter the dataset by a single label, shuffle it, and grab a few samples
        ds_slice = ds['train'].filter(lambda ex: ex['labels'] == label_id).shuffle(seed).select(range(examples_per_class))

        # Plot this label's examples along a row
        for i, example in enumerate(ds_slice):
            image = example['image']
            idx = examples_per_class * label_id + i
            box = (idx % examples_per_class * w, idx // examples_per_class * h)
            grid.paste(image.resize(size), box=box)
            draw.text(box, label, (255, 255, 255), font=font)

    return grid

show_examples(ds, seed=random.randint(0, 1337), examples_per_class=3)
数据集中每个类别的一些示例网格

据我所知,

  • 角斑病:有不规则的棕色斑块
  • 豆锈病:有圆形棕色斑点,周围有白色-黄色环
  • 健康:……看起来很健康。🤷‍♂️

加载 ViT 图像处理器

现在我们知道图像是什么样子以及我们正在努力解决的问题。让我们看看如何为模型准备这些图像!

当 ViT 模型进行训练时,会对其输入的图像应用特定的转换。如果对图像使用了错误的转换,模型将无法理解它所看到的内容!🖼 ➡️ 🔢

为了确保我们应用正确的转换,我们将使用一个 ViTImageProcessor,它使用我们计划使用的预训练模型保存的配置进行初始化。在本例中,我们将使用 google/vit-base-patch16-224-in21k 模型,因此让我们从 Hugging Face Hub 加载其图像处理器。

from transformers import ViTImageProcessor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(model_name_or_path)

您可以通过打印图像处理器配置来查看它。

ViTImageProcessor {
  "do_normalize": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "size": 224
}

要处理图像,只需将其传递给图像处理器的调用函数。这将返回一个包含 pixel values 的字典,这是要传递给模型的数字表示。

默认情况下您会得到一个 NumPy 数组,但如果您添加 return_tensors='pt' 参数,您将得到 torch 张量。

processor(image, return_tensors='pt')

应该会得到类似以下内容...

{
  'pixel_values': tensor([[[[ 0.2706,  0.3255,  0.3804,  ...]]]])
}

...其中张量的形状为 (1, 3, 224, 224)

处理数据集

现在您已经知道如何读取图像并将其转换为输入,让我们编写一个函数,将这两者结合起来,以处理数据集中的单个示例。

def process_example(example):
    inputs = processor(example['image'], return_tensors='pt')
    inputs['labels'] = example['labels']
    return inputs
process_example(ds['train'][0])
{
  'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078,  ..., ]]]]),
  'labels': 0
}

虽然您可以调用 ds.map 并将其一次性应用于每个示例,但这可能会非常慢,特别是如果您使用更大的数据集。相反,您可以对数据集应用一个 *转换*。转换仅在您索引示例时应用。

不过,首先,您需要更新最后一个函数以接受一批数据,因为这就是 ds.with_transform 所期望的。

ds = load_dataset('beans')

def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = processor([x for x in example_batch['image']], return_tensors='pt')

    # Don't forget to include the labels!
    inputs['labels'] = example_batch['labels']
    return inputs

您可以使用 ds.with_transform(transform) 直接将其应用于数据集。

prepared_ds = ds.with_transform(transform)

现在,每当您从数据集中获取一个示例时,转换将实时应用(在样本和切片上,如下所示)

prepared_ds['train'][0:2]

这次,得到的 pixel_values 张量形状将是 (2, 3, 224, 224)

{
  'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078,  ..., ]]]]),
  'labels': [0, 0]
}

训练和评估

数据已处理完毕,您已准备好开始设置训练管道。这篇博文使用 🤗 的 Trainer,但这需要我们先做几件事

  • 定义一个 collate 函数。

  • 定义一个评估指标。在训练期间,模型应根据其预测准确性进行评估。您应相应地定义一个 compute_metrics 函数。

  • 加载预训练的检查点。您需要加载预训练的检查点并正确配置它以进行训练。

  • 定义训练配置。

微调模型后,您将正确评估评估数据上的模型,并验证它确实学会了正确分类图像。

定义我们的数据整理器

批次以字典列表的形式传入,因此您只需将它们解包并堆叠成批次张量即可。

由于 collate_fn 将返回一个批次字典,因此您可以稍后将输入 **解包 到模型中。✨

import torch

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

定义评估指标

evaluate准确度 指标可以轻松用于比较预测与标签。下面,您可以看到如何在 compute_metrics 函数中使用它,该函数将由 Trainer 使用。

import numpy as np
from evaluate import load

metric = load("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

让我们加载预训练模型。我们将在初始化时添加 num_labels,以便模型创建一个具有正确单元数量的分类头。我们还将包含 id2labellabel2id 映射,以便在 Hub 小部件中拥有人类可读的标签(如果您选择 push_to_hub)。

from transformers import ViTForImageClassification

labels = ds['train'].features['labels'].names

model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

差不多准备好训练了!在此之前,最后需要做的是通过定义 TrainingArguments 来设置训练配置。

其中大多数都非常直观,但其中一个非常重要的是 remove_unused_columns=False。这个参数将删除模型调用函数未使用的任何特征。默认情况下它为 True,因为通常最好删除未使用的特征列,这样可以更容易地将输入解包到模型的调用函数中。但是,在我们的例子中,我们需要未使用的特征(尤其是“image”)才能创建“pixel_values”。

我想说的是,如果你忘记设置 remove_unused_columns=False,你将会很糟糕。

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="./vit-base-beans",
  per_device_train_batch_size=16,
  evaluation_strategy="steps",
  num_train_epochs=4,
  fp16=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-4,
  save_total_limit=2,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)

现在,所有实例都可以传递给 Trainer,我们准备开始训练了!

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    tokenizer=processor,
)

训练 🚀

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

评估 📊

metrics = trainer.evaluate(prepared_ds['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

这是我的评估结果——好棒的豆子!抱歉,我必须说出来。

***** eval metrics *****
  epoch                   =        4.0
  eval_accuracy           =      0.985
  eval_loss               =     0.0637
  eval_runtime            = 0:00:02.13
  eval_samples_per_second =     62.356
  eval_steps_per_second   =       7.97

最后,如果您愿意,可以将模型推送到 Hub。在这里,如果您在训练配置中指定了 push_to_hub=True,我们将将其推送到 Hub。请注意,要推送到 Hub,您必须安装 git-lfs 并登录您的 Hugging Face 帐户(可以通过 huggingface-cli login 完成)。

kwargs = {
    "finetuned_from": model.config._name_or_path,
    "tasks": "image-classification",
    "dataset": 'beans',
    "tags": ['image-classification'],
}

if training_args.push_to_hub:
    trainer.push_to_hub('🍻 cheers', **kwargs)
else:
    trainer.create_model_card(**kwargs)

生成的模型已共享至 nateraw/vit-base-beans。我假设您手头没有豆叶图片,所以我添加了一些示例供您尝试!🚀

社区

注册登录 发表评论