LLM 课程文档

使用 FAISS 进行语义搜索

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

使用 FAISS 进行语义搜索

Ask a Question Open In Colab Open In Studio Lab

第 5 节中,我们从 🤗 Datasets 仓库创建了一个 GitHub issue 和评论数据集。在本节中,我们将使用这些信息构建一个搜索引擎,帮助我们找到有关该库最紧迫问题的答案!

使用嵌入进行语义搜索

第 1 章所述,基于 Transformer 的语言模型将文本段落中的每个 token 表示为嵌入向量。事实证明,可以通过“池化”单个嵌入来创建整个句子、段落或(在某些情况下)文档的向量表示。然后,这些嵌入可以通过计算每个嵌入之间的点积相似度(或一些其他相似度度量)来查找语料库中的相似文档,并返回重叠度最大的文档。

在本节中,我们将使用嵌入来开发一个语义搜索引擎。这些搜索引擎比基于查询中关键字与文档匹配的传统方法具有多个优势。

Semantic search.

加载和准备数据集

我们首先需要下载 GitHub issue 数据集,所以我们像往常一样使用 load_dataset() 函数

from datasets import load_dataset

issues_dataset = load_dataset("lewtun/github-issues", split="train")
issues_dataset
Dataset({
    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
    num_rows: 2855
})

这里我们在 load_dataset() 中指定了默认的 train 分割,因此它返回一个 Dataset 而不是 DatasetDict。首要任务是过滤掉拉取请求,因为这些请求通常很少用于回答用户查询,并且会在我们的搜索引擎中引入噪音。现在应该很熟悉了,我们可以使用 Dataset.filter() 函数来排除数据集中的这些行。同时,我们还会过滤掉没有评论的行,因为这些行无法为用户查询提供答案

issues_dataset = issues_dataset.filter(
    lambda x: (x["is_pull_request"] == False and len(x["comments"]) > 0)
)
issues_dataset
Dataset({
    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
    num_rows: 771
})

我们可以看到数据集中有很多列,其中大部分我们不需要构建搜索引擎。从搜索角度来看,信息最丰富的列是 titlebodycomments,而 html_url 为我们提供了指向源 issue 的链接。让我们使用 Dataset.remove_columns() 函数删除其余列

columns = issues_dataset.column_names
columns_to_keep = ["title", "body", "html_url", "comments"]
columns_to_remove = set(columns_to_keep).symmetric_difference(columns)
issues_dataset = issues_dataset.remove_columns(columns_to_remove)
issues_dataset
Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 771
})

为了创建嵌入,我们将为每个评论添加 issue 的标题和正文,因为这些字段通常包含有用的上下文信息。由于我们的 comments 列目前是每个 issue 的评论列表,我们需要“展开”该列,以便每行由一个 (html_url, title, body, comment) 元组组成。在 Pandas 中,我们可以使用 DataFrame.explode() 函数来做到这一点,该函数会为类列表列中的每个元素创建新行,同时复制所有其他列值。为了查看实际效果,我们首先切换到 Pandas DataFrame 格式

issues_dataset.set_format("pandas")
df = issues_dataset[:]

如果我们检查此 DataFrame 的第一行,我们可以看到有四个与此 issue 相关的评论

df["comments"][0].tolist()
['the bug code locate in :\r\n    if data_args.task_name is not None:\r\n        # Downloading and loading a dataset from the hub.\r\n        datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)',
 'Hi @jinec,\r\n\r\nFrom time to time we get this kind of `ConnectionError` coming from the github.com website: https://raw.githubusercontent.com\r\n\r\nNormally, it should work if you wait a little and then retry.\r\n\r\nCould you please confirm if the problem persists?',
 'cannot connect,even by Web browser,please check that  there is some  problems。',
 'I can access https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py without problem...']

当我们展开 df 时,我们期望为每个评论获得一行。让我们检查一下是否如此

comments_df = df.explode("comments", ignore_index=True)
comments_df.head(4)
html_url title comments body
0 https://github.com/huggingface/datasets/issues/2787 ConnectionError: 无法访问 https://raw.githubusercontent.com 错误代码位于:\r\n if data_args.task_name is not None... 你好,\r\n我正在尝试运行 run_glue.py,它给了我这个错误...
1 https://github.com/huggingface/datasets/issues/2787 ConnectionError: 无法访问 https://raw.githubusercontent.com 你好 @jinec,\r\n\r\n我们偶尔会从 github.com 网站收到这种 `ConnectionError`:https://raw.githubusercontent.com... 你好,\r\n我正在尝试运行 run_glue.py,它给了我这个错误...
2 https://github.com/huggingface/datasets/issues/2787 ConnectionError: 无法访问 https://raw.githubusercontent.com 无法连接,甚至通过 Web 浏览器也无法连接,请检查是否存在一些问题。 你好,\r\n我正在尝试运行 run_glue.py,它给了我这个错误...
3 https://github.com/huggingface/datasets/issues/2787 ConnectionError: 无法访问 https://raw.githubusercontent.com 我可以访问 https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py 而没有问题... 你好,\r\n我正在尝试运行 run_glue.py,它给了我这个错误...

太棒了,我们可以看到行已被复制,comments 列包含单个评论!既然我们已经完成了 Pandas 的工作,我们可以通过在内存中加载 DataFrame 快速切换回 Dataset

from datasets import Dataset

comments_dataset = Dataset.from_pandas(comments_df)
comments_dataset
Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 2842
})

好的,这给了我们几千条评论来处理!

✏️ 试一试! 看看你是否可以使用 Dataset.map() 展开 issues_datasetcomments 列,而使用 Pandas。这有点棘手;你可能会发现 🤗 Datasets 文档的“批处理映射”部分对这项任务很有用。

现在我们每行有一个评论,让我们创建一个新的 comments_length 列,其中包含每个评论的字数

comments_dataset = comments_dataset.map(
    lambda x: {"comment_length": len(x["comments"].split())}
)

我们可以使用这个新列来过滤掉短评论,这些评论通常包含“cc @lewtun”或“谢谢!”等与我们的搜索引擎无关的内容。过滤的精确数字没有定论,但大约 15 个单词似乎是一个不错的开始

comments_dataset = comments_dataset.filter(lambda x: x["comment_length"] > 15)
comments_dataset
Dataset({
    features: ['html_url', 'title', 'comments', 'body', 'comment_length'],
    num_rows: 2098
})

清理完数据集后,让我们将 issue 标题、描述和评论串联到一个新的 text 列中。像往常一样,我们将编写一个简单的函数,可以将其传递给 Dataset.map()

def concatenate_text(examples):
    return {
        "text": examples["title"]
        + " \n "
        + examples["body"]
        + " \n "
        + examples["comments"]
    }


comments_dataset = comments_dataset.map(concatenate_text)

我们终于准备好创建嵌入了!让我们来看看。

创建文本嵌入

我们在第 2 章中看到,我们可以通过使用 AutoModel 类来获取 token 嵌入。我们只需要选择一个合适的检查点来加载模型。幸运的是,有一个名为 sentence-transformers 的库专门用于创建嵌入。正如该库的文档所述,我们的用例是非对称语义搜索的一个例子,因为我们有一个简短的查询,我们希望在较长的文档(如 issue 评论)中找到答案。文档中方便的模型概览表表明,multi-qa-mpnet-base-dot-v1 检查点在语义搜索方面具有最佳性能,因此我们将将其用于我们的应用程序。我们还将使用相同的检查点加载分词器

from transformers import AutoTokenizer, AutoModel

model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

为了加快嵌入过程,将模型和输入放置在 GPU 设备上有助于提速,所以我们现在就这样做

import torch

device = torch.device("cuda")
model.to(device)

如前所述,我们希望将 GitHub issue 语料库中的每个条目表示为单个向量,因此我们需要以某种方式“池化”或平均我们的 token 嵌入。一种流行的方法是对模型的输出执行 CLS 池化,我们只需收集特殊 [CLS] token 的最后一个隐藏状态。以下函数为我们解决了这个问题

def cls_pooling(model_output):
    return model_output.last_hidden_state[:, 0]

接下来,我们将创建一个辅助函数,该函数将对文档列表进行标记化,将张量放在 GPU 上,将其馈送到模型,最后对输出应用 CLS 池化

def get_embeddings(text_list):
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    model_output = model(**encoded_input)
    return cls_pooling(model_output)

我们可以通过将语料库中的第一个文本条目输入函数并检查输出形状来测试函数是否正常工作

embedding = get_embeddings(comments_dataset["text"][0])
embedding.shape
torch.Size([1, 768])

太棒了,我们已经将语料库中的第一个条目转换成一个 768 维向量!我们可以使用 Dataset.map() 将我们的 get_embeddings() 函数应用于语料库中的每一行,所以我们接下来创建一个新的 embeddings 列,如下所示

embeddings_dataset = comments_dataset.map(
    lambda x: {"embeddings": get_embeddings(x["text"]).detach().cpu().numpy()[0]}
)

请注意,我们已将嵌入转换为 NumPy 数组 — 这是因为 🤗 Datasets 要求在将它们与 FAISS 建立索引时使用这种格式,我们将在下一步进行。

使用 FAISS 进行高效相似度搜索

现在我们有了一个嵌入数据集,我们需要某种方式来搜索它们。为此,我们将使用 🤗 Datasets 中的一种特殊数据结构,称为 FAISS 索引FAISS(Facebook AI Similarity Search 的缩写)是一个提供高效算法以快速搜索和聚类嵌入向量的库。

FAISS 的基本思想是创建一个名为索引的特殊数据结构,该结构允许查找哪些嵌入与输入嵌入相似。在 🤗 Datasets 中创建 FAISS 索引很简单——我们使用 Dataset.add_faiss_index() 函数并指定我们希望索引的数据集的哪一列

embeddings_dataset.add_faiss_index(column="embeddings")

现在我们可以通过使用 Dataset.get_nearest_examples() 函数进行最近邻查找来对该索引执行查询。让我们通过首先嵌入一个问题来测试一下

question = "How can I load a dataset offline?"
question_embedding = get_embeddings([question]).cpu().detach().numpy()
question_embedding.shape
torch.Size([1, 768])

就像处理文档一样,我们现在有一个 768 维向量表示查询,我们可以将其与整个语料库进行比较,以找到最相似的嵌入

scores, samples = embeddings_dataset.get_nearest_examples(
    "embeddings", question_embedding, k=5
)

Dataset.get_nearest_examples() 函数返回一个包含查询与文档之间重叠程度的得分元组,以及一组相应的样本(此处为 5 个最佳匹配)。让我们将它们收集到 pandas.DataFrame 中,以便我们轻松排序

import pandas as pd

samples_df = pd.DataFrame.from_dict(samples)
samples_df["scores"] = scores
samples_df.sort_values("scores", ascending=False, inplace=True)

现在我们可以遍历前几行,看看我们的查询与可用评论的匹配程度

for _, row in samples_df.iterrows():
    print(f"COMMENT: {row.comments}")
    print(f"SCORE: {row.scores}")
    print(f"TITLE: {row.title}")
    print(f"URL: {row.html_url}")
    print("=" * 50)
    print()
"""
COMMENT: Requiring online connection is a deal breaker in some cases unfortunately so it'd be great if offline mode is added similar to how `transformers` loads models offline fine.

@mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?
SCORE: 25.505046844482422
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: The local dataset builders (csv, text , json and pandas) are now part of the `datasets` package since #1726 :)
You can now use them offline
\`\`\`python
datasets = load_dataset("text", data_files=data_files)
\`\`\`

We'll do a new release soon
SCORE: 24.555509567260742
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: I opened a PR that allows to reload modules that have already been loaded once even if there's no internet.

Let me know if you know other ways that can make the offline mode experience better. I'd be happy to add them :)

I already note the "freeze" modules option, to prevent local modules updates. It would be a cool feature.

----------

> @mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?

Indeed `load_dataset` allows to load remote dataset script (squad, glue, etc.) but also you own local ones.
For example if you have a dataset script at `./my_dataset/my_dataset.py` then you can do
\`\`\`python
load_dataset("./my_dataset")
\`\`\`
and the dataset script will generate your dataset once and for all.

----------

About I'm looking into having `csv`, `json`, `text`, `pandas` dataset builders already included in the `datasets` package, so that they are available offline by default, as opposed to the other datasets that require the script to be downloaded.
cf #1724
SCORE: 24.14896583557129
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: > here is my way to load a dataset offline, but it **requires** an online machine
>
> 1. (online machine)
>
> ```
>
> import datasets
>
> data = datasets.load_dataset(...)
>
> data.save_to_disk(/YOUR/DATASET/DIR)
>
> ```
>
> 2. copy the dir from online to the offline machine
>
> 3. (offline machine)
>
> ```
>
> import datasets
>
> data = datasets.load_from_disk(/SAVED/DATA/DIR)
>
> ```
>
>
>
> HTH.


SCORE: 22.893993377685547
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: here is my way to load a dataset offline, but it **requires** an online machine
1. (online machine)
\`\`\`
import datasets
data = datasets.load_dataset(...)
data.save_to_disk(/YOUR/DATASET/DIR)
\`\`\`
2. copy the dir from online to the offline machine
3. (offline machine)
\`\`\`
import datasets
data = datasets.load_from_disk(/SAVED/DATA/DIR)
\`\`\`

HTH.
SCORE: 22.406635284423828
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
"""

还不错!我们的第二个结果似乎与查询匹配。

✏️ 试一试! 创建你自己的查询,看看是否能在检索到的文档中找到答案。你可能需要增加 Dataset.get_nearest_examples() 中的 k 参数来扩大搜索范围。

< > 在 GitHub 上更新