NLP 课程文档

使用 FAISS 进行语义搜索

Hugging Face's logo
加入 Hugging Face 社区

并获取增强型文档体验的访问权限

开始

使用 FAISS 进行语义搜索

Ask a Question Open In Colab Open In Studio Lab

第 5 节 中,我们创建了 🤗 数据集存储库中 GitHub 问题和评论的数据集。在本节中,我们将使用此信息构建一个搜索引擎,该引擎可以帮助我们找到有关该库的最紧迫问题的答案!

使用嵌入进行语义搜索

正如我们在 第 1 章 中看到的,基于 Transformer 的语言模型将文本跨度中的每个标记表示为一个嵌入向量。事实证明,可以“池化”各个嵌入以创建整个句子、段落或(在某些情况下)文档的向量表示。然后,这些嵌入可用于通过计算每个嵌入之间的点积相似度(或其他相似度度量)来查找语料库中的相似文档,并返回重叠度最大的文档。

在本节中,我们将使用嵌入来开发语义搜索引擎。这些搜索引擎比基于将查询中的关键字与文档匹配的传统方法具有多个优势。

Semantic search.

加载和准备数据集

我们要做的第一件事是下载我们的 GitHub 问题数据集,因此像往常一样使用 load_dataset() 函数

from datasets import load_dataset

issues_dataset = load_dataset("lewtun/github-issues", split="train")
issues_dataset
Dataset({
    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
    num_rows: 2855
})

在这里,我们已在 load_dataset() 中指定了默认的 train 拆分,因此它返回 Dataset 而不是 DatasetDict。首要任务是过滤掉拉取请求,因为这些请求很少用于回答用户查询,并且会在我们的搜索引擎中引入噪声。正如我们现在所熟悉的,我们可以使用 Dataset.filter() 函数来排除数据集中这些行。趁热打铁,我们还将过滤掉没有评论的行,因为这些行没有为用户查询提供答案

issues_dataset = issues_dataset.filter(
    lambda x: (x["is_pull_request"] == False and len(x["comments"]) > 0)
)
issues_dataset
Dataset({
    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
    num_rows: 771
})

我们可以看到数据集中有很多列,其中大多数我们不需要构建搜索引擎。从搜索的角度来看,最具信息量的列是 titlebodycomments,而 html_url 为我们提供了指向源问题的链接。让我们使用 Dataset.remove_columns() 函数来删除其余的列

columns = issues_dataset.column_names
columns_to_keep = ["title", "body", "html_url", "comments"]
columns_to_remove = set(columns_to_keep).symmetric_difference(columns)
issues_dataset = issues_dataset.remove_columns(columns_to_remove)
issues_dataset
Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 771
})

为了创建我们的嵌入,我们将为每条评论添加问题的标题和正文,因为这些字段通常包含有用的上下文信息。由于我们的 comments 列当前是每个问题的评论列表,因此我们需要“展开”该列,以便每行都包含 (html_url, title, body, comment) 元组。在 Pandas 中,我们可以使用 DataFrame.explode() 函数 来完成此操作,该函数为列表式列中的每个元素创建一个新行,同时复制所有其他列值。为了看到它的实际效果,让我们先切换到 Pandas DataFrame 格式

issues_dataset.set_format("pandas")
df = issues_dataset[:]

如果我们检查此 DataFrame 中的第一行,我们可以看到与该问题关联的四条评论

df["comments"][0].tolist()
['the bug code locate in :\r\n    if data_args.task_name is not None:\r\n        # Downloading and loading a dataset from the hub.\r\n        datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)',
 'Hi @jinec,\r\n\r\nFrom time to time we get this kind of `ConnectionError` coming from the github.com website: https://raw.githubusercontent.com\r\n\r\nNormally, it should work if you wait a little and then retry.\r\n\r\nCould you please confirm if the problem persists?',
 'cannot connect,even by Web browser,please check that  there is some  problems。',
 'I can access https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py without problem...']

当我们展开 df 时,我们希望为这些评论中的每一个获得一行。让我们检查一下是否确实如此

comments_df = df.explode("comments", ignore_index=True)
comments_df.head(4)
html_url title comments body
0 https://github.com/huggingface/datasets/issues/2787 ConnectionError:无法访问 https://raw.githubusercontent.com 错误代码位于:\r\n if data_args.task_name is not None... 您好,\r\n我正在尝试运行 run_glue.py,但它给了我此错误...
1 https://github.com/huggingface/datasets/issues/2787 ConnectionError:无法访问 https://raw.githubusercontent.com 您好 @jinec,\r\n\r\n我们不时地会遇到来自 github.com 网站的这种 `ConnectionError`:https://raw.githubusercontent.com... 您好,\r\n我正在尝试运行 run_glue.py,但它给了我此错误...
2 https://github.com/huggingface/datasets/issues/2787 ConnectionError:无法访问 https://raw.githubusercontent.com 无法连接,即使通过 Web 浏览器,也请检查是否有一些问题。 您好,\r\n我正在尝试运行 run_glue.py,但它给了我此错误...
3 https://github.com/huggingface/datasets/issues/2787 ConnectionError:无法访问 https://raw.githubusercontent.com 我可以在没有问题的情况下访问 https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py... 您好,\r\n我正在尝试运行 run_glue.py,但它给了我此错误...

很好,我们可以看到这些行已被复制,comments 列包含各个评论!现在我们已经完成了 Pandas 的操作,可以通过将 DataFrame 加载到内存中快速切换回 Dataset

from datasets import Dataset

comments_dataset = Dataset.from_pandas(comments_df)
comments_dataset
Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 2842
})

好的,这给了我们几千条评论可以处理!

✏️ 试一试! 看看是否可以使用 Dataset.map() 来展开 issues_datasetcomments而无需使用 Pandas。这有点棘手;你可能会发现 🤗 数据集文档的 “批处理映射” 部分对完成此任务很有用。

现在我们每行都有一条评论,让我们创建一个新的 comments_length 列,其中包含每条评论的词数

comments_dataset = comments_dataset.map(
    lambda x: {"comment_length": len(x["comments"].split())}
)

我们可以使用此新列来过滤掉简短的评论,这些评论通常包括“cc @lewtun”或“谢谢!”之类的内容,这些内容与我们的搜索引擎无关。没有一个确切的数字可以用来选择过滤器,但大约 15 个单词似乎是一个不错的起点

comments_dataset = comments_dataset.filter(lambda x: x["comment_length"] > 15)
comments_dataset
Dataset({
    features: ['html_url', 'title', 'comments', 'body', 'comment_length'],
    num_rows: 2098
})

整理完数据集后,让我们在新 text 列中将问题的标题、描述和评论连接在一起。像往常一样,我们将编写一个简单函数,可以将其传递给 Dataset.map()

def concatenate_text(examples):
    return {
        "text": examples["title"]
        + " \n "
        + examples["body"]
        + " \n "
        + examples["comments"]
    }


comments_dataset = comments_dataset.map(concatenate_text)

我们终于准备好创建一些嵌入!让我们来看看。

创建文本嵌入

我们在第 2 章中看到,我们可以使用 AutoModel 类获取词元嵌入。我们所要做的就是选择一个合适的检查点来加载模型。幸运的是,有一个名为 sentence-transformers 的库专门用于创建嵌入。正如该库的文档中所述,我们的用例是非对称语义搜索的一个例子,因为我们有一个短查询,我们希望在较长的文档中找到它的答案,例如一个问题评论。文档中的便捷模型概览表表明 multi-qa-mpnet-base-dot-v1 检查点在语义搜索方面具有最佳性能,因此我们将把它用于我们的应用程序。我们还将使用相同的检查点加载分词器

from transformers import AutoTokenizer, AutoModel

model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

为了加快嵌入过程,将模型和输入放在 GPU 设备上会有所帮助,所以我们现在就来做吧

import torch

device = torch.device("cuda")
model.to(device)

正如我们之前提到的,我们希望将 GitHub 问题语料库中的每个条目表示为一个单个向量,因此我们需要以某种方式“池化”或平均我们的词元嵌入。一种流行的方法是对模型的输出执行CLS 池化,在这里我们只需收集特殊 [CLS] 词元的最后一个隐藏状态。以下函数为我们完成了这项工作

def cls_pooling(model_output):
    return model_output.last_hidden_state[:, 0]

接下来,我们将创建一个辅助函数,该函数将对文档列表进行分词,将张量放在 GPU 上,将其馈送到模型,最后将 CLS 池化应用于输出

def get_embeddings(text_list):
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    model_output = model(**encoded_input)
    return cls_pooling(model_output)

我们可以通过将语料库中的第一个文本条目馈送到该函数并检查输出形状来测试该函数是否有效

embedding = get_embeddings(comments_dataset["text"][0])
embedding.shape
torch.Size([1, 768])

太好了,我们已经将语料库中的第一个条目转换为一个 768 维向量!我们可以使用 Dataset.map() 将我们的 get_embeddings() 函数应用于语料库中的每一行,所以让我们创建一个新的 embeddings 列,如下所示

embeddings_dataset = comments_dataset.map(
    lambda x: {"embeddings": get_embeddings(x["text"]).detach().cpu().numpy()[0]}
)

请注意,我们已经将嵌入转换为 NumPy 数组——这是因为 🤗 Datasets 在我们尝试使用 FAISS 对其进行索引时需要这种格式,我们将在下一步进行。

使用 FAISS 进行高效的相似性搜索

现在我们已经有了嵌入数据集,我们需要一种方法来搜索它们。为此,我们将使用 🤗 Datasets 中的一种特殊数据结构,称为FAISS 索引FAISS(代表 Facebook AI Similarity Search)是一个库,它提供了高效的算法来快速搜索和聚类嵌入向量。

FAISS 背后的基本思想是创建一个称为索引的特殊数据结构,该结构允许人们找到哪些嵌入与输入嵌入相似。在 🤗 Datasets 中创建 FAISS 索引很简单——我们使用 Dataset.add_faiss_index() 函数并指定我们想要索引的数据集的哪个列

embeddings_dataset.add_faiss_index(column="embeddings")

我们现在可以通过使用 Dataset.get_nearest_examples() 函数进行最近邻查找来对该索引执行查询。让我们通过首先将问题嵌入来测试一下

question = "How can I load a dataset offline?"
question_embedding = get_embeddings([question]).cpu().detach().numpy()
question_embedding.shape
torch.Size([1, 768])

就像文档一样,我们现在有一个 768 维向量来表示查询,我们可以将其与整个语料库进行比较以找到最相似的嵌入

scores, samples = embeddings_dataset.get_nearest_examples(
    "embeddings", question_embedding, k=5
)

Dataset.get_nearest_examples() 函数返回一个分数元组,该元组对查询和文档之间的重叠进行排名,以及一组相应的样本(这里,是 5 个最佳匹配)。让我们将它们收集到一个 pandas.DataFrame 中,这样我们就可以轻松地对其进行排序

import pandas as pd

samples_df = pd.DataFrame.from_dict(samples)
samples_df["scores"] = scores
samples_df.sort_values("scores", ascending=False, inplace=True)

现在我们可以遍历前几行以查看我们的查询与可用的评论匹配程度如何

for _, row in samples_df.iterrows():
    print(f"COMMENT: {row.comments}")
    print(f"SCORE: {row.scores}")
    print(f"TITLE: {row.title}")
    print(f"URL: {row.html_url}")
    print("=" * 50)
    print()
"""
COMMENT: Requiring online connection is a deal breaker in some cases unfortunately so it'd be great if offline mode is added similar to how `transformers` loads models offline fine.

@mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?
SCORE: 25.505046844482422
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: The local dataset builders (csv, text , json and pandas) are now part of the `datasets` package since #1726 :)
You can now use them offline
\`\`\`python
datasets = load_dataset("text", data_files=data_files)
\`\`\`

We'll do a new release soon
SCORE: 24.555509567260742
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: I opened a PR that allows to reload modules that have already been loaded once even if there's no internet.

Let me know if you know other ways that can make the offline mode experience better. I'd be happy to add them :)

I already note the "freeze" modules option, to prevent local modules updates. It would be a cool feature.

----------

> @mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?

Indeed `load_dataset` allows to load remote dataset script (squad, glue, etc.) but also you own local ones.
For example if you have a dataset script at `./my_dataset/my_dataset.py` then you can do
\`\`\`python
load_dataset("./my_dataset")
\`\`\`
and the dataset script will generate your dataset once and for all.

----------

About I'm looking into having `csv`, `json`, `text`, `pandas` dataset builders already included in the `datasets` package, so that they are available offline by default, as opposed to the other datasets that require the script to be downloaded.
cf #1724
SCORE: 24.14896583557129
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: > here is my way to load a dataset offline, but it **requires** an online machine
>
> 1. (online machine)
>
> ```
>
> import datasets
>
> data = datasets.load_dataset(...)
>
> data.save_to_disk(/YOUR/DATASET/DIR)
>
> ```
>
> 2. copy the dir from online to the offline machine
>
> 3. (offline machine)
>
> ```
>
> import datasets
>
> data = datasets.load_from_disk(/SAVED/DATA/DIR)
>
> ```
>
>
>
> HTH.


SCORE: 22.893993377685547
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: here is my way to load a dataset offline, but it **requires** an online machine
1. (online machine)
\`\`\`
import datasets
data = datasets.load_dataset(...)
data.save_to_disk(/YOUR/DATASET/DIR)
\`\`\`
2. copy the dir from online to the offline machine
3. (offline machine)
\`\`\`
import datasets
data = datasets.load_from_disk(/SAVED/DATA/DIR)
\`\`\`

HTH.
SCORE: 22.406635284423828
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
"""

还不错!我们的第二个命中似乎与查询匹配。

✏️ 尝试一下!创建自己的查询并查看你是否可以在检索到的文档中找到答案。你可能需要增加 Dataset.get_nearest_examples() 中的 k 参数以扩大搜索范围。