LLM 课程文档
使用 FAISS 进行语义搜索
并获得增强的文档体验
开始使用
使用 FAISS 进行语义搜索
在第 5 节中,我们从 🤗 Datasets 仓库创建了一个 GitHub issues 和评论的数据集。在本节中,我们将使用这些信息构建一个搜索引擎,它可以帮助我们找到关于该库最紧迫问题的答案!
使用嵌入向量进行语义搜索
正如我们在第 1 章中看到的,基于 Transformer 的语言模型将一段文本中的每个 token 表示为一个嵌入向量。事实证明,可以“池化”各个嵌入向量,为整个句子、段落或(在某些情况下)文档创建向量表示。然后,可以通过计算每个嵌入向量之间的点积相似度(或某些其他相似度指标)来使用这些嵌入向量在语料库中查找相似的文档,并返回重叠程度最高的文档。
在本节中,我们将使用嵌入向量开发一个语义搜索引擎。与传统的基于查询中的关键字与文档匹配的方法相比,这些搜索引擎具有多个优势。
加载和准备数据集
我们需要做的第一件事是下载我们的 GitHub issues 数据集,所以让我们像往常一样使用 load_dataset()
函数
from datasets import load_dataset
issues_dataset = load_dataset("lewtun/github-issues", split="train")
issues_dataset
Dataset({
features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
num_rows: 2855
})
在这里,我们在 load_dataset()
中指定了默认的 train
拆分,因此它返回的是 Dataset
而不是 DatasetDict
。首要任务是过滤掉 pull 请求,因为这些请求很少用于回答用户查询,并且会在我们的搜索引擎中引入噪声。现在应该很熟悉了,我们可以使用 Dataset.filter()
函数来排除数据集中的这些行。同时,让我们也过滤掉没有评论的行,因为这些行没有提供用户查询的答案
issues_dataset = issues_dataset.filter(
lambda x: (x["is_pull_request"] == False and len(x["comments"]) > 0)
)
issues_dataset
Dataset({
features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
num_rows: 771
})
我们可以看到我们的数据集中有很多列,其中大多数是我们构建搜索引擎不需要的。从搜索的角度来看,信息量最大的列是 title
、body
和 comments
,而 html_url
为我们提供了返回源 issue 的链接。让我们使用 Dataset.remove_columns()
函数删除其余列
columns = issues_dataset.column_names
columns_to_keep = ["title", "body", "html_url", "comments"]
columns_to_remove = set(columns_to_keep).symmetric_difference(columns)
issues_dataset = issues_dataset.remove_columns(columns_to_remove)
issues_dataset
Dataset({
features: ['html_url', 'title', 'comments', 'body'],
num_rows: 771
})
为了创建我们的嵌入向量,我们将 issue 的标题和正文添加到每个评论中,因为这些字段通常包含有用的上下文信息。由于我们的 comments
列当前是每个 issue 的评论列表,我们需要“展开”该列,以便每行都包含一个 (html_url, title, body, comment)
元组。在 Pandas 中,我们可以使用 DataFrame.explode()
函数来执行此操作,该函数为类似列表的列中的每个元素创建一个新行,同时复制所有其他列值。为了了解它的实际效果,让我们首先切换到 Pandas DataFrame
格式
issues_dataset.set_format("pandas")
df = issues_dataset[:]
如果我们检查此 DataFrame
中的第一行,我们可以看到与此 issue 关联的四个评论
df["comments"][0].tolist()
['the bug code locate in :\r\n if data_args.task_name is not None:\r\n # Downloading and loading a dataset from the hub.\r\n datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)',
'Hi @jinec,\r\n\r\nFrom time to time we get this kind of `ConnectionError` coming from the github.com website: https://raw.githubusercontent.com\r\n\r\nNormally, it should work if you wait a little and then retry.\r\n\r\nCould you please confirm if the problem persists?',
'cannot connect,even by Web browser,please check that there is some problems。',
'I can access https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py without problem...']
当我们展开 df
时,我们期望为每个评论获得一行。让我们检查一下是否是这种情况
comments_df = df.explode("comments", ignore_index=True)
comments_df.head(4)
html_url | 标题 | 评论 | 正文 | |
---|---|---|---|---|
0 | https://github.com/huggingface/datasets/issues/2787 | ConnectionError: Couldn't reach https://raw.githubusercontent.com | bug 代码位于:\r\n if data_args.task_name is not None... | 你好,\r\n我正在尝试运行 run_glue.py,但它给了我这个错误... |
1 | https://github.com/huggingface/datasets/issues/2787 | ConnectionError: Couldn't reach https://raw.githubusercontent.com | 嗨 @jinec,\r\n\r\n我们时不时会收到来自 github.com 网站的这种 `ConnectionError`:https://raw.githubusercontent.com... | 你好,\r\n我正在尝试运行 run_glue.py,但它给了我这个错误... |
2 | https://github.com/huggingface/datasets/issues/2787 | ConnectionError: Couldn't reach https://raw.githubusercontent.com | 无法连接,即使通过 Web 浏览器,请检查是否存在一些问题。 | 你好,\r\n我正在尝试运行 run_glue.py,但它给了我这个错误... |
3 | https://github.com/huggingface/datasets/issues/2787 | ConnectionError: Couldn't reach https://raw.githubusercontent.com | 我可以毫无问题地访问 https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py... | 你好,\r\n我正在尝试运行 run_glue.py,但它给了我这个错误... |
太棒了,我们可以看到行已被复制,comments
列包含各个评论!既然我们已经完成了 Pandas 的操作,我们可以通过将 DataFrame
加载到内存中来快速切换回 Dataset
from datasets import Dataset
comments_dataset = Dataset.from_pandas(comments_df)
comments_dataset
Dataset({
features: ['html_url', 'title', 'comments', 'body'],
num_rows: 2842
})
好的,这为我们提供了数千条评论来处理!
✏️ 试试看! 看看您是否可以使用 Dataset.map()
来展开 issues_dataset
的 comments
列,而无需求助于 Pandas 的使用。这有点棘手;您可能会发现 🤗 Datasets 文档的 “批量映射” 部分对此任务很有用。
现在我们每行有一个评论,让我们创建一个新的 comments_length
列,其中包含每个评论的字数
comments_dataset = comments_dataset.map(
lambda x: {"comment_length": len(x["comments"].split())}
)
我们可以使用这个新列来过滤掉简短的评论,这些评论通常包含诸如“cc @lewtun”或“谢谢!”之类的与我们的搜索引擎无关的内容。对于过滤器,没有精确的数字可供选择,但大约 15 个字似乎是一个不错的起点
comments_dataset = comments_dataset.filter(lambda x: x["comment_length"] > 15)
comments_dataset
Dataset({
features: ['html_url', 'title', 'comments', 'body', 'comment_length'],
num_rows: 2098
})
清理完数据集后,让我们将 issue 标题、描述和评论连接在一起,放在一个新的 text
列中。像往常一样,我们将编写一个简单的函数,我们可以将其传递给 Dataset.map()
def concatenate_text(examples):
return {
"text": examples["title"]
+ " \n "
+ examples["body"]
+ " \n "
+ examples["comments"]
}
comments_dataset = comments_dataset.map(concatenate_text)
我们终于准备好创建一些嵌入向量了!让我们看看。
创建文本嵌入向量
我们在第 2 章中看到,我们可以通过使用 AutoModel
类获得 token 嵌入向量。我们所需要做的就是选择一个合适的 checkpoint 来从中加载模型。幸运的是,有一个名为 sentence-transformers
的库专门用于创建嵌入向量。正如该库的文档中所述,我们的用例是非对称语义搜索的一个示例,因为我们有一个简短的查询,我们想在较长的文档(如 issue 评论)中找到其答案。模型概览表中的便捷信息表明,multi-qa-mpnet-base-dot-v1
checkpoint 在语义搜索方面具有最佳性能,因此我们将使用它来用于我们的应用程序。我们还将使用相同的 checkpoint 加载分词器
from transformers import AutoTokenizer, AutoModel
model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)
为了加快嵌入过程,将模型和输入放置在 GPU 设备上有助于提高速度,所以现在就来做这件事
import torch
device = torch.device("cuda")
model.to(device)
正如我们前面提到的,我们希望将 GitHub issues 语料库中的每个条目表示为单个向量,因此我们需要以某种方式“池化”或平均我们的 token 嵌入向量。一种流行的方法是对我们模型的输出执行 CLS 池化,我们只需收集特殊 [CLS]
token 的最后一个隐藏状态。以下函数可以为我们完成这项工作
def cls_pooling(model_output):
return model_output.last_hidden_state[:, 0]
接下来,我们将创建一个辅助函数,该函数将对文档列表进行 token 化,将张量放置在 GPU 上,将其馈送到模型,最后将 CLS 池化应用于输出
def get_embeddings(text_list):
encoded_input = tokenizer(
text_list, padding=True, truncation=True, return_tensors="pt"
)
encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
model_output = model(**encoded_input)
return cls_pooling(model_output)
我们可以通过将语料库中的第一个文本条目馈送到该函数并检查输出形状来测试该函数是否有效
embedding = get_embeddings(comments_dataset["text"][0])
embedding.shape
torch.Size([1, 768])
太棒了,我们已将语料库中的第一个条目转换为 768 维向量!我们可以使用 Dataset.map()
将我们的 get_embeddings()
函数应用于语料库中的每一行,因此让我们创建一个新的 embeddings
列,如下所示
embeddings_dataset = comments_dataset.map(
lambda x: {"embeddings": get_embeddings(x["text"]).detach().cpu().numpy()[0]}
)
请注意,我们已将嵌入向量转换为 NumPy 数组 —— 这是因为当我们尝试使用 FAISS 对它们进行索引时,🤗 Datasets 需要这种格式,我们将在下一步执行此操作。
使用 FAISS 进行高效相似性搜索
现在我们有了一个嵌入向量数据集,我们需要某种方法来搜索它们。为此,我们将使用 🤗 Datasets 中称为 FAISS 索引的特殊数据结构。FAISS(Facebook AI Similarity Search 的缩写)是一个库,它提供了高效的算法来快速搜索和聚类嵌入向量。
FAISS 背后的基本思想是创建一个称为索引的特殊数据结构,该结构允许人们找到哪些嵌入向量与输入嵌入向量相似。在 🤗 Datasets 中创建 FAISS 索引很简单 —— 我们使用 Dataset.add_faiss_index()
函数并指定我们要索引的数据集的哪一列
embeddings_dataset.add_faiss_index(column="embeddings")
我们现在可以通过使用 Dataset.get_nearest_examples()
函数执行最近邻查找来对此索引执行查询。让我们通过首先嵌入一个问题来测试一下
question = "How can I load a dataset offline?"
question_embedding = get_embeddings([question]).cpu().detach().numpy()
question_embedding.shape
torch.Size([1, 768])
与文档一样,我们现在有一个 768 维向量来表示查询,我们可以将其与整个语料库进行比较,以找到最相似的嵌入向量
scores, samples = embeddings_dataset.get_nearest_examples(
"embeddings", question_embedding, k=5
)
Dataset.get_nearest_examples()
函数返回一个分数元组,该元组对查询和文档之间的重叠进行排名,以及一组相应的样本(此处为 5 个最佳匹配项)。让我们将这些收集到 pandas.DataFrame
中,以便我们可以轻松地对它们进行排序
import pandas as pd
samples_df = pd.DataFrame.from_dict(samples)
samples_df["scores"] = scores
samples_df.sort_values("scores", ascending=False, inplace=True)
现在我们可以迭代前几行,看看我们的查询与可用的评论匹配得如何
for _, row in samples_df.iterrows():
print(f"COMMENT: {row.comments}")
print(f"SCORE: {row.scores}")
print(f"TITLE: {row.title}")
print(f"URL: {row.html_url}")
print("=" * 50)
print()
"""
COMMENT: Requiring online connection is a deal breaker in some cases unfortunately so it'd be great if offline mode is added similar to how `transformers` loads models offline fine.
@mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?
SCORE: 25.505046844482422
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
COMMENT: The local dataset builders (csv, text , json and pandas) are now part of the `datasets` package since #1726 :)
You can now use them offline
\`\`\`python
datasets = load_dataset("text", data_files=data_files)
\`\`\`
We'll do a new release soon
SCORE: 24.555509567260742
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
COMMENT: I opened a PR that allows to reload modules that have already been loaded once even if there's no internet.
Let me know if you know other ways that can make the offline mode experience better. I'd be happy to add them :)
I already note the "freeze" modules option, to prevent local modules updates. It would be a cool feature.
----------
> @mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?
Indeed `load_dataset` allows to load remote dataset script (squad, glue, etc.) but also you own local ones.
For example if you have a dataset script at `./my_dataset/my_dataset.py` then you can do
\`\`\`python
load_dataset("./my_dataset")
\`\`\`
and the dataset script will generate your dataset once and for all.
----------
About I'm looking into having `csv`, `json`, `text`, `pandas` dataset builders already included in the `datasets` package, so that they are available offline by default, as opposed to the other datasets that require the script to be downloaded.
cf #1724
SCORE: 24.14896583557129
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
COMMENT: > here is my way to load a dataset offline, but it **requires** an online machine
>
> 1. (online machine)
>
> ```
>
> import datasets
>
> data = datasets.load_dataset(...)
>
> data.save_to_disk(/YOUR/DATASET/DIR)
>
> ```
>
> 2. copy the dir from online to the offline machine
>
> 3. (offline machine)
>
> ```
>
> import datasets
>
> data = datasets.load_from_disk(/SAVED/DATA/DIR)
>
> ```
>
>
>
> HTH.
SCORE: 22.893993377685547
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
COMMENT: here is my way to load a dataset offline, but it **requires** an online machine
1. (online machine)
\`\`\`
import datasets
data = datasets.load_dataset(...)
data.save_to_disk(/YOUR/DATASET/DIR)
\`\`\`
2. copy the dir from online to the offline machine
3. (offline machine)
\`\`\`
import datasets
data = datasets.load_from_disk(/SAVED/DATA/DIR)
\`\`\`
HTH.
SCORE: 22.406635284423828
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
"""
还不错!我们的第二个命中似乎与查询匹配。
✏️ 试试看! 创建您自己的查询,看看您是否可以在检索到的文档中找到答案。您可能需要在 Dataset.get_nearest_examples()
中增加 k
参数以扩大搜索范围。