Transformers 文档

计算机视觉中的知识蒸馏

Hugging Face's logo
加入 Hugging Face 社区

并获得增强文档体验

开始使用

计算机视觉知识蒸馏

知识蒸馏是一种将知识从更大、更复杂的模型(教师)转移到更小、更简单的模型(学生)的技术。为了将一个模型的知识蒸馏到另一个模型,我们采用一个在特定任务(在本例中为图像分类)上训练过的预训练教师模型,并随机初始化一个用于图像分类训练的学生模型。接下来,我们训练学生模型以最小化其输出与教师输出之间的差异,从而使其模仿教师的行为。它首次在 Hinton 等人撰写的《神经网络中的知识蒸馏》 中被提出。在本指南中,我们将进行特定于任务的知识蒸馏。我们将使用 beans 数据集 来完成此操作。

本指南演示了如何使用 🤗 Transformers 的 Trainer API 将一个 微调的 ViT 模型(教师模型)蒸馏到一个 MobileNet(学生模型)。

让我们安装蒸馏和评估过程中所需的库。

pip install transformers datasets accelerate tensorboard evaluate --upgrade

在本例中,我们使用 merve/beans-vit-224 模型作为教师模型。它是一个基于 google/vit-base-patch16-224-in21k 在 beans 数据集上微调的图像分类模型。我们将把这个模型蒸馏到一个随机初始化的 MobileNetV2。

我们现在将加载数据集。

from datasets import load_dataset

dataset = load_dataset("beans")

我们可以使用任何一个模型的图像处理器,因为在本例中,它们使用相同的分辨率返回相同的输出。我们将使用 datasetmap() 方法将预处理应用于数据集的每个拆分。

from transformers import AutoImageProcessor
teacher_processor = AutoImageProcessor.from_pretrained("merve/beans-vit-224")

def process(examples):
    processed_inputs = teacher_processor(examples["image"])
    return processed_inputs

processed_datasets = dataset.map(process, batched=True)

从本质上讲,我们希望学生模型(一个随机初始化的 MobileNet)模仿教师模型(微调的视觉转换器)。为了实现这一点,我们首先获取教师和学生模型的 logits 输出。然后,我们将每个输出除以参数 temperature,该参数控制每个软目标的重要性。一个名为 lambda 的参数权衡了蒸馏损失的重要性。在本例中,我们将使用 temperature=5lambda=0.5。我们将使用 Kullback-Leibler 散度损失来计算学生和教师之间的散度。给定两个数据 P 和 Q,KL 散度解释了使用 Q 表示 P 需要多少额外信息。如果两者相同,则它们的 KL 散度为零,因为不需要其他信息来从 Q 解释 P。因此,在知识蒸馏的背景下,KL 散度很有用。

from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F


class ImageDistilTrainer(Trainer):
    def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None,  *args, **kwargs):
        super().__init__(model=student_model, *args, **kwargs)
        self.teacher = teacher_model
        self.student = student_model
        self.loss_function = nn.KLDivLoss(reduction="batchmean")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.teacher.to(device)
        self.teacher.eval()
        self.temperature = temperature
        self.lambda_param = lambda_param

    def compute_loss(self, student, inputs, return_outputs=False):
        student_output = self.student(**inputs)

        with torch.no_grad():
          teacher_output = self.teacher(**inputs)

        # Compute soft targets for teacher and student
        soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)

        # Compute the loss
        distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)

        # Compute the true label loss
        student_target_loss = student_output.loss

        # Calculate final loss
        loss = (1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss
        return (loss, student_output) if return_outputs else loss

我们现在将登录 Hugging Face Hub,以便我们能够通过 Trainer 将我们的模型推送到 Hugging Face Hub。

from huggingface_hub import notebook_login

notebook_login()

让我们设置 TrainingArguments、教师模型和学生模型。

from transformers import AutoModelForImageClassification, MobileNetV2Config, MobileNetV2ForImageClassification

training_args = TrainingArguments(
    output_dir="my-awesome-model",
    num_train_epochs=30,
    fp16=True,
    logging_dir=f"{repo_name}/logs",
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="tensorboard",
    push_to_hub=True,
    hub_strategy="every_save",
    hub_model_id=repo_name,
    )

num_labels = len(processed_datasets["train"].features["labels"].names)

# initialize models
teacher_model = AutoModelForImageClassification.from_pretrained(
    "merve/beans-vit-224",
    num_labels=num_labels,
    ignore_mismatched_sizes=True
)

# training MobileNetV2 from scratch
student_config = MobileNetV2Config()
student_config.num_labels = num_labels
student_model = MobileNetV2ForImageClassification(student_config)

我们可以使用 compute_metrics 函数来评估我们模型在测试集上的表现。此函数将在训练过程中用于计算模型的 accuracyf1

import evaluate
import numpy as np

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    acc = accuracy.compute(references=labels, predictions=np.argmax(predictions, axis=1))
    return {"accuracy": acc["accuracy"]}

让我们使用我们定义的训练参数初始化 Trainer。我们还将初始化我们的数据整理器。

from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()
trainer = ImageDistilTrainer(
    student_model=student_model,
    teacher_model=teacher_model,
    training_args=training_args,
    train_dataset=processed_datasets["train"],
    eval_dataset=processed_datasets["validation"],
    data_collator=data_collator,
    tokenizer=teacher_processor,
    compute_metrics=compute_metrics,
    temperature=5,
    lambda_param=0.5
)

现在我们可以训练我们的模型了。

trainer.train()

我们可以评估模型在测试集上的表现。

trainer.evaluate(processed_datasets["test"])

在测试集上,我们的模型达到了 72% 的准确率。为了对蒸馏的效率进行合理性检查,我们还在 beans 数据集上从头开始训练了 MobileNet,使用相同的超参数,并在测试集上观察到 63% 的准确率。我们邀请读者尝试不同的预训练教师模型、学生架构、蒸馏参数并报告他们的发现。蒸馏模型的训练日志和检查点可以在 此存储库 中找到,从头开始训练的 MobileNetV2 可以在 此存储库 中找到。

< > 在 GitHub 上更新