使用 FAISS 进行语义搜索
在 第 5 章 中,我们从 🤗 Datasets 库创建了一个包含 GitHub 问题和评论的数据集。在本节中,我们将使用这些信息构建一个搜索引擎,帮助我们找到有关该库的最紧迫问题的答案!
使用嵌入进行语义搜索
正如我们在 第 1 章 中看到的,基于 Transformer 的语言模型将文本片段中的每个标记表示为一个嵌入向量。事实证明,可以“池化”各个嵌入来创建整个句子、段落或(在某些情况下)文档的向量表示。然后,可以通过计算每个嵌入与查询之间的点积相似度(或其他一些相似度度量)并返回重叠最大的文档,来使用这些嵌入在语料库中查找类似的文档。
在本节中,我们将使用嵌入来开发一个语义搜索引擎。与基于将查询中的关键词与文档进行匹配的传统方法相比,这些搜索引擎提供了几个优势。
加载和准备数据集
首先,我们需要下载 GitHub 问题的 dataset,因此像往常一样使用 load_dataset()
函数
from datasets import load_dataset
issues_dataset = load_dataset("lewtun/github-issues", split="train")
issues_dataset
Dataset({
features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
num_rows: 2855
})
在这里,我们在 load_dataset()
中指定了默认的 train
分割,因此它返回一个 Dataset
而不是 DatasetDict
。首要任务是过滤掉 pull request,因为这些请求很少用于回答用户查询,并且会在我们的搜索引擎中引入噪声。正如我们现在应该熟悉的那样,我们可以使用 Dataset.filter()
函数来排除数据集中的这些行。趁此机会,我们还可以过滤掉没有评论的行,因为这些行没有提供对用户查询的任何答案
issues_dataset = issues_dataset.filter(
lambda x: (x["is_pull_request"] == False and len(x["comments"]) > 0)
)
issues_dataset
Dataset({
features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
num_rows: 771
})
我们可以看到 dataset 中有很多列,其中大多数我们不需要构建搜索引擎。从搜索的角度来看,信息量最大的列是 title
、body
和 comments
,而 html_url
为我们提供了指向源问题的链接。让我们使用 Dataset.remove_columns()
函数删除其余列
columns = issues_dataset.column_names
columns_to_keep = ["title", "body", "html_url", "comments"]
columns_to_remove = set(columns_to_keep).symmetric_difference(columns)
issues_dataset = issues_dataset.remove_columns(columns_to_remove)
issues_dataset
Dataset({
features: ['html_url', 'title', 'comments', 'body'],
num_rows: 771
})
为了创建嵌入,我们将为每个评论添加问题的标题和正文,因为这些字段通常包含有用的上下文信息。因为我们的 comments
列当前是每个问题的评论列表,所以我们需要“展开”该列,以便每一行都包含一个 (html_url, title, body, comment)
元组。在 Pandas 中,我们可以使用 DataFrame.explode()
函数 来实现这一点,该函数为列表型列中的每个元素创建一个新行,同时复制所有其他列值。为了查看其工作原理,让我们首先切换到 Pandas DataFrame
格式
issues_dataset.set_format("pandas")
df = issues_dataset[:]
如果我们检查此 DataFrame
中的第一行,我们可以看到与此问题关联的四个评论
df["comments"][0].tolist()
['the bug code locate in :\r\n if data_args.task_name is not None:\r\n # Downloading and loading a dataset from the hub.\r\n datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)',
'Hi @jinec,\r\n\r\nFrom time to time we get this kind of `ConnectionError` coming from the github.com website: https://raw.githubusercontent.com\r\n\r\nNormally, it should work if you wait a little and then retry.\r\n\r\nCould you please confirm if the problem persists?',
'cannot connect,even by Web browser,please check that there is some problems。',
'I can access https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py without problem...']
当我们展开 df
时,我们期望为每个评论获得一行。让我们检查一下是否如此
comments_df = df.explode("comments", ignore_index=True)
comments_df.head(4)
html_url | 标题 | 评论 | 正文 | |
---|---|---|---|---|
0 | https://github.com/huggingface/datasets/issues/2787 | ConnectionError: 无法访问 https://raw.githubusercontent.com | 错误代码位于:\r\n if data_args.task_name is not None... | 您好,\r\n我正在尝试运行 run_glue.py,但它给了我这个错误... |
1 | https://github.com/huggingface/datasets/issues/2787 | ConnectionError: 无法访问 https://raw.githubusercontent.com | 你好 @jinec,\r\n\r\n我们不时会收到来自 github.com 网站的这种 `ConnectionError`:https://raw.githubusercontent.com... | 您好,\r\n我正在尝试运行 run_glue.py,但它给了我这个错误... |
2 | https://github.com/huggingface/datasets/issues/2787 | ConnectionError: 无法访问 https://raw.githubusercontent.com | 无法连接,即使通过 Web 浏览器,请检查是否存在某些问题。 | 您好,\r\n我正在尝试运行 run_glue.py,但它给了我这个错误... |
3 | https://github.com/huggingface/datasets/issues/2787 | ConnectionError: 无法访问 https://raw.githubusercontent.com | 我可以毫无问题地访问 https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py... | 您好,\r\n我正在尝试运行 run_glue.py,但它给了我这个错误... |
很好,我们可以看到行已被复制,并且 comments
列包含各个评论!现在我们已经完成了 Pandas 的操作,可以通过将 DataFrame
加载到内存中快速切换回 Dataset
from datasets import Dataset
comments_dataset = Dataset.from_pandas(comments_df)
comments_dataset
Dataset({
features: ['html_url', 'title', 'comments', 'body'],
num_rows: 2842
})
好的,这给了我们几千条评论可以使用!
✏️ 试一试! 看看你是否可以使用 Dataset.map()
展开 issues_dataset
的 comments
列,无需诉诸于使用 Pandas。这有点棘手;你可能会发现 🤗 Datasets 文档的 “批处理映射” 部分对此任务很有用。
现在我们每个行一个评论了,让我们创建一个新的 comments_length
列,其中包含每个评论的单词数
comments_dataset = comments_dataset.map(
lambda x: {"comment_length": len(x["comments"].split())}
)
我们可以使用这个新列来过滤掉短评论,这些评论通常包含诸如“cc @lewtun”或“谢谢!”之类的信息,这些信息与我们的搜索引擎无关。没有精确的数字可用于选择过滤器,但大约 15 个词似乎是一个不错的起点
comments_dataset = comments_dataset.filter(lambda x: x["comment_length"] > 15)
comments_dataset
Dataset({
features: ['html_url', 'title', 'comments', 'body', 'comment_length'],
num_rows: 2098
})
在清理了 dataset 之后,让我们在一个新的 text
列中将问题的标题、描述和评论连接在一起。像往常一样,我们将编写一个简单的函数,我们可以将其传递给 Dataset.map()
def concatenate_text(examples):
return {
"text": examples["title"]
+ " \n "
+ examples["body"]
+ " \n "
+ examples["comments"]
}
comments_dataset = comments_dataset.map(concatenate_text)
我们终于准备好创建一些嵌入向量了!让我们来看一下。
创建文本嵌入向量
我们在第 2 章中看到,我们可以使用AutoModel
类获取词元嵌入向量。我们只需要选择一个合适的检查点来加载模型。幸运的是,有一个名为sentence-transformers
的库专门用于创建嵌入向量。如库的文档中所述,我们的用例是非对称语义搜索的一个例子,因为我们有一个简短的查询,我们希望在较长的文档(如问题评论)中找到答案。文档中方便的模型概述表表明,multi-qa-mpnet-base-dot-v1
检查点在语义搜索方面具有最佳性能,因此我们将将其用于我们的应用程序。我们还将使用相同的检查点加载分词器。
from transformers import AutoTokenizer, AutoModel
model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)
为了加快嵌入向量过程,将模型和输入放在GPU设备上会有所帮助,所以我们现在就来做这件事。
import torch
device = torch.device("cuda")
model.to(device)
正如我们前面提到的,我们希望将GitHub问题语料库中的每个条目表示为一个单一的向量,因此我们需要以某种方式“池化”或平均我们的词元嵌入向量。一种流行的方法是对模型的输出执行CLS池化,我们只需收集特殊[CLS]
词元的最后一个隐藏状态。以下函数可以帮我们做到这一点。
def cls_pooling(model_output):
return model_output.last_hidden_state[:, 0]
接下来,我们将创建一个辅助函数,该函数将对文档列表进行分词,将张量放到GPU上,将其馈送到模型,最后对输出应用CLS池化。
def get_embeddings(text_list):
encoded_input = tokenizer(
text_list, padding=True, truncation=True, return_tensors="pt"
)
encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
model_output = model(**encoded_input)
return cls_pooling(model_output)
我们可以通过将语料库中的第一个文本条目馈送到函数并检查输出形状来测试函数是否正常工作。
embedding = get_embeddings(comments_dataset["text"][0])
embedding.shape
torch.Size([1, 768])
太好了,我们已经将语料库中的第一个条目转换为一个768维的向量!我们可以使用Dataset.map()
将我们的get_embeddings()
函数应用于语料库中的每一行,因此让我们创建一个新的embeddings
列,如下所示。
embeddings_dataset = comments_dataset.map(
lambda x: {"embeddings": get_embeddings(x["text"]).detach().cpu().numpy()[0]}
)
请注意,我们已将嵌入向量转换为NumPy数组——这是因为当我们尝试使用FAISS对其进行索引时,🤗 Datasets需要这种格式,我们将在下一步中执行此操作。
使用FAISS进行高效相似性搜索
现在我们有了嵌入向量数据集,我们需要一些方法来搜索它们。为此,我们将使用🤗 Datasets中一个名为FAISS索引的特殊数据结构。FAISS(Facebook AI Similarity Search的缩写)是一个库,它提供了高效的算法来快速搜索和聚类嵌入向量。
FAISS背后的基本思想是创建一个称为索引的特殊数据结构,允许人们找到哪些嵌入向量与输入嵌入向量相似。在🤗 Datasets中创建FAISS索引很简单——我们使用Dataset.add_faiss_index()
函数并指定我们想要索引的数据集的哪一列。
embeddings_dataset.add_faiss_index(column="embeddings")
我们现在可以通过使用Dataset.get_nearest_examples()
函数进行最近邻查找来对该索引执行查询。让我们通过首先嵌入一个问题来测试一下,如下所示。
question = "How can I load a dataset offline?"
question_embedding = get_embeddings([question]).cpu().detach().numpy()
question_embedding.shape
torch.Size([1, 768])
就像文档一样,我们现在有一个768维的向量表示查询,我们可以将其与整个语料库进行比较以找到最相似的嵌入向量。
scores, samples = embeddings_dataset.get_nearest_examples(
"embeddings", question_embedding, k=5
)
Dataset.get_nearest_examples()
函数返回一个元组,其中包含对查询和文档之间重叠程度进行排序的分数,以及一组相应的样本(此处为5个最佳匹配)。让我们将它们收集到pandas.DataFrame
中,以便我们可以轻松地对其进行排序。
import pandas as pd
samples_df = pd.DataFrame.from_dict(samples)
samples_df["scores"] = scores
samples_df.sort_values("scores", ascending=False, inplace=True)
现在我们可以遍历前几行,看看我们的查询与可用的评论匹配程度如何。
for _, row in samples_df.iterrows():
print(f"COMMENT: {row.comments}")
print(f"SCORE: {row.scores}")
print(f"TITLE: {row.title}")
print(f"URL: {row.html_url}")
print("=" * 50)
print()
"""
COMMENT: Requiring online connection is a deal breaker in some cases unfortunately so it'd be great if offline mode is added similar to how `transformers` loads models offline fine.
@mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?
SCORE: 25.505046844482422
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
COMMENT: The local dataset builders (csv, text , json and pandas) are now part of the `datasets` package since #1726 :)
You can now use them offline
\`\`\`python
datasets = load_dataset("text", data_files=data_files)
\`\`\`
We'll do a new release soon
SCORE: 24.555509567260742
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
COMMENT: I opened a PR that allows to reload modules that have already been loaded once even if there's no internet.
Let me know if you know other ways that can make the offline mode experience better. I'd be happy to add them :)
I already note the "freeze" modules option, to prevent local modules updates. It would be a cool feature.
----------
> @mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?
Indeed `load_dataset` allows to load remote dataset script (squad, glue, etc.) but also you own local ones.
For example if you have a dataset script at `./my_dataset/my_dataset.py` then you can do
\`\`\`python
load_dataset("./my_dataset")
\`\`\`
and the dataset script will generate your dataset once and for all.
----------
About I'm looking into having `csv`, `json`, `text`, `pandas` dataset builders already included in the `datasets` package, so that they are available offline by default, as opposed to the other datasets that require the script to be downloaded.
cf #1724
SCORE: 24.14896583557129
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
COMMENT: > here is my way to load a dataset offline, but it **requires** an online machine
>
> 1. (online machine)
>
> ```
>
> import datasets
>
> data = datasets.load_dataset(...)
>
> data.save_to_disk(/YOUR/DATASET/DIR)
>
> ```
>
> 2. copy the dir from online to the offline machine
>
> 3. (offline machine)
>
> ```
>
> import datasets
>
> data = datasets.load_from_disk(/SAVED/DATA/DIR)
>
> ```
>
>
>
> HTH.
SCORE: 22.893993377685547
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
COMMENT: here is my way to load a dataset offline, but it **requires** an online machine
1. (online machine)
\`\`\`
import datasets
data = datasets.load_dataset(...)
data.save_to_disk(/YOUR/DATASET/DIR)
\`\`\`
2. copy the dir from online to the offline machine
3. (offline machine)
\`\`\`
import datasets
data = datasets.load_from_disk(/SAVED/DATA/DIR)
\`\`\`
HTH.
SCORE: 22.406635284423828
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
"""
还不错!我们的第二个命中似乎与查询相匹配。
✏️ 试试看!创建您自己的查询,看看您是否可以在检索到的文档中找到答案。您可能需要增加Dataset.get_nearest_examples()
中的k
参数以扩大搜索范围。