视觉Transformer (ViT) 微调过程中嵌入的变化

TL;DR
微调显著影响图像分类中的嵌入。微调前嵌入提供通用表示,而微调后嵌入捕获任务特定特征。这种区别可能导致异常值检测和其他任务的结果不同。微调前和微调后嵌入都有其独特的优势,应结合使用以实现图像分类和分析任务的全面分析。
查看本文 CIFAR-10 数据集 [3] 的在线演示之一
1 引言
在大型数据集(如ImageNet)上使用预训练模型,然后针对特定目标数据集进行微调,已成为图像分类的默认方法。然而,在处理现实世界目标数据集时,必须考虑其固有的噪声,包括异常值、标签错误和其他异常情况。数据集的交互式探索在全面理解数据方面发挥着关键作用,通过利用数据丰富,可以识别和解决关键数据段。
嵌入在分析非结构化图像数据中扮演着关键角色。它们提供高级语义信息,并支持各种任务,如数据分析、洞察生成和异常值检测。通过在较低维度空间中表示图像,嵌入使得探索数据内部的相似性和差异变得更加容易,并允许使用t-SNE 或 UMAP 等技术创建相似性图。我们将使用 GitHub 上可用的 Renumics Spotlight (github.com/Renumics/spotlight) 来交互式探索我们创建的丰富数据集。

免责声明:本文作者也是Spotlight的开发者之一。本文中的部分代码片段也可在Spotlight仓库中找到。
本文将深入探讨微调前和微调后嵌入的差异,并额外关注异常值检测。虽然需要注意的是,使用微调模型中的嵌入可能并不总是能为异常值检测带来最佳结果,因为我们也可以使用概率,但这仍然是一种引人入胜的方法。嵌入的可视化为分析过程增添了视觉吸引力。
为了评估嵌入在异常值检测任务中的性能和有效性,我们将检查图像分类中广泛使用的示例数据集。此外,我们将利用两个常见的通用模型。通过这次探索,我们旨在深入了解模型微调对嵌入的影响,从而更好地理解其能力和局限性。
2 准备工作
安装所需的Python包
!pip install renumics-spotlight datasets torch pandas cleanlab annoy
2.1 提取嵌入
我们将使用基于 google/vit-base-patch16–224-in21k [1] 和 microsoft/swin-base-patch4-window7–224 [2] 的 Hugging Faces 模型来提取微调前嵌入以及每个数据集最受欢迎的微调模型:araki/vit-base-patch16–224-in21k-finetuned-cifar10、MazenAmria/swin-tiny-finetuned-cifar100、nateraw/vit-base-beans、farleyknight/mnist-digit-classification-2022–09–04。
case = {
"cifar10": {
"base_model_name": "google/vit-base-patch16-224-in21k",
"ft_model_name": "aaraki/vit-base-patch16-224-in21k-finetuned-cifar10",
},
"beans": {
"base_model_name": "google/vit-base-patch16-224-in21k",
"ft_model_name": "nateraw/vit-base-beans",
},
"mnist": {
"base_model_name": "google/vit-base-patch16-224-in21k",
"ft_model_name": "farleyknight/mnist-digit-classification-2022-09-04",
},
"cifar100": {
"base_model_name": "microsoft/swin-base-patch4-window7-224",
"ft_model_name": "MazenAmria/swin-tiny-finetuned-cifar100",
},
}
为了加载数据集,我们使用datasets模块中的load_dataset函数,并为图像分类任务做好准备。您可以选择本文中经过测试和报告的数据集:CIFAR-10 [3]、CIFAR-100 [3]、MNIST [4] 和 Beans [5],或者尝试来自Hugging Face的不同图像分类数据集以及相应的模型。
import datasets
# choose from cifar10, cifar100, mnist or beans.
# corresponding model will be selected automatically
DATASET = "cifar10"
ds = datasets.load_dataset(DATASET, split="train").prepare_for_task(
"image-classification"
)
df = ds.to_pandas()
# df = df.iloc[:1000] # uncomment to limit the dataset size for testing
我们定义 `huggingface_embedding` 函数,用于从微调模型和基础/基础模型中提取嵌入。嵌入存储在原始数据帧 (df) 中的独立列("embedding_ft" 和 "embedding_foundation")中。
import datasets
from transformers import AutoFeatureExtractor, AutoModel
import torch
import pandas as pd
ft_model_name = case[DATASET]["ft_model_name"]
base_model_name = case[DATASET]["base_model_name"]
def extract_embeddings(model, feature_extractor, image_name="image"):
"""
Utility to compute embeddings.
Args:
model: huggingface model
feature_extractor: huggingface feature extractor
image_name: name of the image column in the dataset
Returns:
function to compute embeddings
"""
device = model.device
def pp(batch):
images = batch[image_name]
inputs = feature_extractor(
images=[x.convert("RGB") for x in images], return_tensors="pt"
).to(device)
embeddings = model(**inputs).last_hidden_state[:, 0].cpu()
return {"embedding": embeddings}
return pp
def huggingface_embedding(
df,
image_name="image",
modelname="google/vit-base-patch16-224",
batched=True,
batch_size=24,
):
"""
Compute embeddings using huggingface models.
Args:
df: dataframe with images
image_name: name of the image column in the dataset
modelname: huggingface model name
batched: whether to compute embeddings in batches
batch_size: batch size
Returns:
new dataframe with embeddings
"""
# initialize huggingface model
feature_extractor = AutoFeatureExtractor.from_pretrained(modelname)
model = AutoModel.from_pretrained(modelname, output_hidden_states=True)
# create huggingface dataset from df
dataset = datasets.Dataset.from_pandas(df).cast_column(image_name, datasets.Image())
# compute embedding
device = "cuda" if torch.cuda.is_available() else "cpu"
extract_fn = extract_embeddings(model.to(device), feature_extractor, image_name)
updated_dataset = dataset.map(extract_fn, batched=batched, batch_size=batch_size)
df_temp = updated_dataset.to_pandas()
df_emb = pd.DataFrame()
df_emb["embedding"] = df_temp["embedding"]
return df_emb
embeddings_df = huggingface_embedding(
df,
modelname=ft_model_name,
batched=True,
batch_size=24,
)
embeddings_df_found = huggingface_embedding(
df, modelname=base_model_name, batched=True, batch_size=24
)
df["embedding_ft"] = embeddings_df["embedding"]
df["embedding_foundation"] = embeddings_df_found["embedding"]
2.2 计算异常值分数
接下来,我们使用Cleanlab来计算微调模型和基于嵌入的基础模型的异常值分数。我们利用`OutOfDistribution`类来计算异常值分数。计算出的异常值分数存储在原始数据框(df)中。
from cleanlab.outlier import OutOfDistribution
import numpy as np
import pandas as pd
def outlier_score_by_embeddings_cleanlab(df, embedding_name="embedding"):
"""
Calculate outlier score by embeddings using cleanlab
Args:
df: dataframe with embeddings
embedding_name: name of the column with embeddings
Returns:
new df_out: dataframe with outlier score
"""
embs = np.stack(df[embedding_name].to_numpy())
ood = OutOfDistribution()
ood_train_feature_scores = ood.fit_score(features=np.stack(embs))
df_out = pd.DataFrame()
df_out["outlier_score_embedding"] = ood_train_feature_scores
return df_out
df["outlier_score_ft"] = outlier_score_by_embeddings_cleanlab(
df, embedding_name="embedding_ft"
)["outlier_score_embedding"]
df["outlier_score_found"] = outlier_score_by_embeddings_cleanlab(
df, embedding_name="embedding_foundation"
)["outlier_score_embedding"]
2.3 查找最近邻
为了评估异常值,我们仅使用微调模型计算最近邻图像,使用 Annoy 库。结果图像存储在原始 DataFrame (df) 中。
from annoy import AnnoyIndex
import pandas as pd
def nearest_neighbor_annoy(
df, embedding_name="embedding", threshold=0.3, tree_size=100
):
"""
Find nearest neighbor using annoy.
Args:
df: dataframe with embeddings
embedding_name: name of the embedding column
threshold: threshold for outlier detection
tree_size: tree size for annoy
Returns:
new dataframe with nearest neighbor information
"""
embs = df[embedding_name]
t = AnnoyIndex(len(embs[0]), "angular")
for idx, x in enumerate(embs):
t.add_item(idx, x)
t.build(tree_size)
images = df["image"]
df_nn = pd.DataFrame()
nn_id = [t.get_nns_by_item(i, 2)[1] for i in range(len(embs))]
df_nn["nn_id"] = nn_id
df_nn["nn_image"] = [images[i] for i in nn_id]
df_nn["nn_distance"] = [t.get_distance(i, nn_id[i]) for i in range(len(embs))]
df_nn["nn_flag"] = df_nn.nn_distance < threshold
return df_nn
df_nn = nearest_neighbor_annoy(
df, embedding_name="embedding_ft", threshold=0.3, tree_size=100
)
df["nn_image"] = df_nn["nn_image"]
2.4 可视化
为了使用 Spotlight 进行可视化,通过使用 lambda 函数将整数标签映射到其字符串表示,在 DataFrame 中创建了一个新的“label_str”列。`dtypes` 字典用于指定每列的数据类型以获得正确的可视化,而 `layout` 决定了可视化中的排列和显示的列。
from renumics import spotlight
df["label_str"] = df["labels"].apply(lambda x: ds.features["labels"].int2str(x))
dtypes = {
"nn_image": spotlight.Image,
"image": spotlight.Image,
"embedding_ft": spotlight.Embedding,
"embedding_foundation": spotlight.Embedding,
}
spotlight.show(
df,
dtype=dtypes,
layout="https://spotlight.renumics.com/resources//layout_pre_post_ft.json",
)
这将打开一个新浏览器窗口

在可视化部分,左上角显示了一个综合表格,其中包含数据集中所有存在的字段。通过基础模型嵌入分类为异常值的图像被选中。在右上角,您可以看到两个 UMAP 表示:第一个表示由基础模型生成的嵌入,第二个表示由微调模型生成的嵌入。在底部,选定的图像及其在数据集中的最近邻一起显示。
3 结果
现在让我们检查所有数据集的结果。您可以按照第2节的所有步骤,使用不同的输入数据集来重现结果,或者使用下面的代码片段加载预处理数据集。或者您也可以查看链接的在线演示。
3.1 CIFAR-10
加载已准备好的CIFAR-10数据集 [3],并
from renumics import spotlight
import datasets
ds = datasets.load_dataset("renumics/cifar10-outlier", split="train")
df = ds.rename_columns({"img": "image", "label": "labels"}).to_pandas()
df["label_str"] = df["labels"].apply(lambda x: ds.features["label"].int2str(x))
dtypes = {
"nn_image": spotlight.Image,
"image": spotlight.Image,
"embedding_ft": spotlight.Embedding,
"embedding_foundation": spotlight.Embedding,
}
spotlight.show(
df,
dtype=dtypes,
layout="https://spotlight.renumics.com/resources/layout_pre_post_ft.json",
)
或访问在线演示 https://huggingface.co/spaces/renumics/cifar10-outlier 查看异常值


微调后嵌入的 UMAP 可视化揭示了清晰的模式,其中某些类别完全与其他所有类别分离,而某些类别可能仅与一个或两个其他类别连接。
使用微调前嵌入在 CIFAR-10 中检测到的异常值似乎并不特别罕见,因为它们具有相对相似的邻近图像。相比之下,使用微调后嵌入识别出的异常值在数据集中是独特且非常罕见的。
3.2 CIFAR-100
加载已准备好的CIFAR-100数据集 [3],并
from renumics import spotlight
import datasets
ds = datasets.load_dataset("renumics/cifar100-outlier", split="train")
df = ds.rename_columns({"img": "image", "fine_label": "labels"}).to_pandas()
df["label_str"] = df["labels"].apply(lambda x: ds.features["fine_label"].int2str(x))
dtypes = {
"nn_image": spotlight.Image,
"image": spotlight.Image,
"embedding_ft": spotlight.Embedding,
"embedding_foundation": spotlight.Embedding,
}
spotlight.show(
df,
dtype=dtypes,
layout="https://spotlight.renumics.com/resources/layout_pre_post_ft.json",
)
或访问在线演示 huggingface.co/spaces/renumics/cifar100-outlier 查看异常值


在检查由100个类别组成的CIFAR-100的嵌入时,我们观察到即使在微调之后,与微调前嵌入相比,仍有更多相互连接的类别。然而,嵌入空间内的结构变得明显更清晰和更有组织。
预微调嵌入未显示出与相邻图像明显不同的异常值,表明在异常值检测方面的效果有限。然而,当使用微调后嵌入时,性能有所提高。在识别出的六个异常值中,前三个被有效地检测为数据集中不常见的。
3.3 MNIST
加载已准备好的MNIST数据集 [4],并
from renumics import spotlight
import datasets
ds = datasets.load_dataset("renumics/mnist-outlier", split="train")
df = ds.rename_columns({"label": "labels"}).to_pandas()
df["label_str"] = df["labels"].apply(lambda x: ds.features["label"].int2str(x))
dtypes = {
"nn_image": spotlight.Image,
"image": spotlight.Image,
"embedding_ft": spotlight.Embedding,
"embedding_foundation": spotlight.Embedding,
}
spotlight.show(
df,
dtype=dtypes,
layout="https://spotlight.renumics.com/resources/layout_pre_post_ft.json",
)
或访问在线演示 huggingface.co/spaces/renumics/mnist-outlier 查看异常值


在 MNIST 微调过程中,嵌入发生了显著变化。微调前,不同数字类别之间可能存在重叠区域,使得仅基于嵌入邻近度难以区分它们。然而,微调后,嵌入在数字类别之间表现出更清晰的分离。
微调前嵌入仅显示一个与相邻图像明显不同的异常值,表明异常值检测性能中等。然而,当使用微调后嵌入时,异常值检测有所改善。大约可以识别出3到4个异常值在数据集中非常罕见。
3.4 豆类
加载已准备好的豆类数据集 [3],并
from renumics import spotlight
import datasets
ds = datasets.load_dataset("renumics/beans-outlier", split="train")
df = ds.to_pandas()
df["label_str"] = df["labels"].apply(lambda x: ds.features["labels"].int2str(x))
dtypes = {
"nn_image": spotlight.Image,
"image": spotlight.Image,
"embedding_ft": spotlight.Embedding,
"embedding_foundation": spotlight.Embedding,
}
spotlight.show(
df,
dtype=dtypes,
layout="https://spotlight.renumics.com/resources/layout_pre_post_ft.json",
)
或访问在线演示 huggingface.co/spaces/renumics/beans-outlier 查看异常值


在豆类数据集中,微调后,大多数嵌入在三个类别之间显示出完全分离。然而,少数情况下仍然存在轻微重叠,这可能是由于某些豆类类型之间的相似性或分类错误造成的。
使用微调前和微调后嵌入的异常值检测并未产生明显偏离正常值的异常值。识别出的异常值在数据集中既不独特也不罕见。
4 结论
总而言之,微调对图像分类中的嵌入具有显著影响。微调前,嵌入提供通用表示,而微调后,它们捕获任务特定的特征。
这种区别在 UMAP 可视化中清晰地体现出来,其中微调后的嵌入呈现出更具结构化的模式,某些类别与其他类别完全分离。
对于异常值检测,使用微调后嵌入可能更有效。然而,值得注意的是,根据微调获得的概率计算异常值,可能比单纯依赖嵌入产生更好的结果。
微调前和微调后嵌入都具有其独特的优势,应结合使用以实现图像分类和分析任务的全面分析。
参考文献
[1] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby 一张图片胜过16x16个词:用于大规模图像识别的Transformer (2020), arXiv
[2] Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo Swin Transformer:使用滑动窗口的层次化视觉Transformer (2021), arXiv
[3] Alex Krizhevsky, 从微小图像中学习多层特征 (2009), 多伦多大学
[4] Yann LeCun, Corinna Cortes, Christopher J.C. Burges, MNIST手写数字数据库 (2010), ATT Labs [在线]
[5] Makerere AI Lab, 豆类疾病数据集 (2020), AIR Lab Makerere University