使用 Hugging Face Datasets 和 Transformers 实现图像相似性
在这篇文章中,您将学习如何使用 🤗 Transformers 构建图像相似性系统。找出查询图像和潜在候选图像之间的相似性是信息检索系统(例如反向图像搜索)的一个重要用例。该系统试图回答的问题是:给定一个查询图像和一组候选图像,哪些图像与查询图像最相似。
我们将利用 🤗 datasets
库,因为它无缝支持并行处理,这在构建此系统时会派上用场。
尽管本文使用基于 ViT 的模型(nateraw/vit-base-beans
)和特定数据集(Beans),但它可以扩展到使用其他支持视觉模态的模型和其他图像数据集。您可以尝试的一些著名模型包括:
此外,本文中介绍的方法也有可能扩展到其他模态。
要研究完全可用的图像相似性系统,您可以参考开头链接的 Colab Notebook。
我们如何定义相似性?
为了构建这个系统,我们首先需要定义如何计算两幅图像之间的相似性。一种广泛流行的方法是计算给定图像的密集表示(嵌入),然后使用余弦相似度度量来确定两幅图像的相似程度。
在本文中,我们将使用“嵌入”来表示向量空间中的图像。这为我们提供了一种很好的方式,可以将图像的高维像素空间(例如 224 x 224 x 3)有意义地压缩到更低的维度(例如 768)。这样做主要优点是减少了后续步骤中的计算时间。

计算嵌入
为了从图像中计算嵌入,我们将使用一个视觉模型,该模型对如何在向量空间中表示输入图像有一定的理解。这种类型的模型也通常被称为图像编码器。
为了加载模型,我们利用 AutoModel
类。它为我们提供了一个接口,用于从 Hugging Face Hub 加载任何兼容的模型检查点。除了模型,我们还加载了与模型关联的处理器用于数据预处理。
from transformers import AutoImageProcessor, AutoModel
model_ckpt = "nateraw/vit-base-beans"
processor = AutoImageProcessor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)
在这种情况下,检查点是通过在 beans
数据集上微调 基于 Vision Transformer 的模型获得的。
这里可能会出现一些问题
问题 1:为什么我们不使用 AutoModelForImageClassification
?
这是因为我们想要获得图像的密集表示,而不是离散类别,而这正是 AutoModelForImageClassification
所能提供的。
问题 2:为什么是这个特定的检查点?
如前所述,我们正在使用特定数据集来构建系统。因此,与其使用通用模型(例如在 ImageNet-1k 数据集上训练的模型),不如使用已在所用数据集上微调过的模型。这样,底层模型能更好地理解输入图像。
请注意,您也可以使用通过自监督预训练获得的检查点。该检查点不一定非要来自监督学习。事实上,如果预训练得当,自监督模型可以产生令人印象深刻的检索性能。
现在我们有了一个用于计算嵌入的模型,我们需要一些候选图像来进行查询。
加载候选图像数据集
过一段时间,我们将构建哈希表,将候选图像映射到哈希值。在查询时,我们将使用这些哈希表。我们将在相应的章节中详细讨论哈希表,但目前,为了获得一组候选图像,我们将使用 beans
数据集的 train
分割。
from datasets import load_dataset
dataset = load_dataset("beans")
这是训练分割中的一个样本:

该数据集有三个特征:
dataset["train"].features
>>> {'image_file_path': Value(dtype='string', id=None),
'image': Image(decode=True, id=None),
'labels': ClassLabel(names=['angular_leaf_spot', 'bean_rust', 'healthy'], id=None)}
为了演示图像相似性系统,我们将使用候选图像数据集中的 100 个样本,以缩短总体运行时间。
num_samples = 100
seed = 42
candidate_subset = dataset["train"].shuffle(seed=seed).select(range(num_samples))
查找相似图像的过程
下面是获取相似图像过程的图示概览。

将上图分解一下,我们有:
- 从候选图像(
candidate_subset
)中提取嵌入,并将其存储在一个矩阵中。 - 获取查询图像并提取其嵌入。
- 迭代嵌入矩阵(在步骤 1 中计算),并计算查询嵌入和当前候选嵌入之间的相似性得分。我们通常维护一个类似字典的映射,以保持候选图像的某个标识符和相似性得分之间的对应关系。
- 根据相似性得分对映射结构进行排序,并返回底层标识符。我们使用这些标识符来获取候选样本。
我们可以编写一个简单的实用程序并将其 map()
到我们的候选图像数据集上,以高效地计算嵌入。
import torch
def extract_embeddings(model: torch.nn.Module):
"""Utility to compute embeddings."""
device = model.device
def pp(batch):
images = batch["image"]
# `transformation_chain` is a compostion of preprocessing
# transformations we apply to the input images to prepare them
# for the model. For more details, check out the accompanying Colab Notebook.
image_batch_transformed = torch.stack(
[transformation_chain(image) for image in images]
)
new_batch = {"pixel_values": image_batch_transformed.to(device)}
with torch.no_grad():
embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()
return {"embeddings": embeddings}
return pp
我们可以像这样映射 extract_embeddings()
:
device = "cuda" if torch.cuda.is_available() else "cpu"
extract_fn = extract_embeddings(model.to(device))
candidate_subset_emb = candidate_subset.map(extract_fn, batched=True, batch_size=batch_size)
接下来,为了方便起见,我们创建一个包含候选图像标识符的列表。
candidate_ids = []
for id in tqdm(range(len(candidate_subset_emb))):
label = candidate_subset_emb[id]["labels"]
# Create a unique indentifier.
entry = str(id) + "_" + str(label)
candidate_ids.append(entry)
我们将使用所有候选图像的嵌入矩阵来计算与查询图像的相似性得分。我们已经计算了候选图像的嵌入。在下一个单元格中,我们只是将它们收集到一个矩阵中。
all_candidate_embeddings = np.array(candidate_subset_emb["embeddings"])
all_candidate_embeddings = torch.from_numpy(all_candidate_embeddings)
我们将使用余弦相似度来计算两个嵌入向量之间的相似度分数。然后,我们将使用它来根据给定的查询样本获取相似的候选样本。
def compute_scores(emb_one, emb_two):
"""Computes cosine similarity between two vectors."""
scores = torch.nn.functional.cosine_similarity(emb_one, emb_two)
return scores.numpy().tolist()
def fetch_similar(image, top_k=5):
"""Fetches the `top_k` similar images with `image` as the query."""
# Prepare the input query image for embedding computation.
image_transformed = transformation_chain(image).unsqueeze(0)
new_batch = {"pixel_values": image_transformed.to(device)}
# Comute the embedding.
with torch.no_grad():
query_embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()
# Compute similarity scores with all the candidate images at one go.
# We also create a mapping between the candidate image identifiers
# and their similarity scores with the query image.
sim_scores = compute_scores(all_candidate_embeddings, query_embeddings)
similarity_mapping = dict(zip(candidate_ids, sim_scores))
# Sort the mapping dictionary and return `top_k` candidates.
similarity_mapping_sorted = dict(
sorted(similarity_mapping.items(), key=lambda x: x[1], reverse=True)
)
id_entries = list(similarity_mapping_sorted.keys())[:top_k]
ids = list(map(lambda x: int(x.split("_")[0]), id_entries))
labels = list(map(lambda x: int(x.split("_")[-1]), id_entries))
return ids, labels
执行查询
有了所有这些实用工具,我们就可以进行相似性搜索了。让我们从 beans
数据集的 test
分割中获取一个查询图像:
test_idx = np.random.choice(len(dataset["test"]))
test_sample = dataset["test"][test_idx]["image"]
test_label = dataset["test"][test_idx]["labels"]
sim_ids, sim_labels = fetch_similar(test_sample)
print(f"Query label: {test_label}")
print(f"Top 5 candidate labels: {sim_labels}")
导致
Query label: 0
Top 5 candidate labels: [0, 0, 0, 0, 0]
看来我们的系统找到了正确的相似图像集。当可视化时,我们会得到:

进一步的扩展和结论
我们现在有了一个可用的图像相似性系统。但在现实中,您将处理更多的候选图像。考虑到这一点,我们目前的程序有几个缺点:
- 如果我们将嵌入原样存储,内存需求会迅速飙升,尤其是在处理数百万张候选图像时。在我们的案例中,嵌入是 768 维的,在大规模场景下仍然相对较高。
- 高维嵌入对检索部分涉及的后续计算有直接影响。
如果我们能够在不干扰嵌入含义的情况下降低其维度,我们仍然可以在速度和检索质量之间保持良好的权衡。本文的配套 Colab Notebook实现了并演示了使用随机投影和局部敏感哈希实现这一点的实用工具。
🤗 Datasets 提供与 FAISS 的直接集成,这进一步简化了构建相似性系统的过程。假设您已经提取了候选图像(beans
数据集)的嵌入并将其存储在一个名为 embeddings
的特征中。您现在可以轻松使用数据集的 add_faiss_index()
来构建一个密集索引:
dataset_with_embeddings.add_faiss_index(column="embeddings")
一旦索引构建完成,dataset_with_embeddings
就可以用于使用 get_nearest_examples()
获取给定查询嵌入的最近邻示例。
scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
"embeddings", qi_embedding, k=top_k
)
该方法返回得分和相应的候选示例。要了解更多信息,您可以查看官方文档和此笔记本。
最后,您可以尝试以下 Space,它构建了一个迷你图像相似性应用程序:
在这篇文章中,我们快速介绍了构建图像相似性系统。如果您觉得这篇文章很有趣,我们强烈建议您在此基础上进行构建,以便您能更熟悉其内部工作原理。
还在寻找更多学习资料?以下是一些对您可能有用的额外资源: