搜索索引
FAISS 和 Elasticsearch 可以搜索数据集中示例。当您想要检索与 NLP 任务相关的特定数据集示例时,这将非常有用。例如,如果您正在处理开放域问答任务,您可能只想返回与回答您的问题相关的示例。
本指南将向您展示如何为数据集构建索引,以便您可以搜索它。
FAISS
FAISS 基于文档的向量表示的相似性来检索文档。在本例中,您将使用 DPR 模型生成向量表示。
- 从 🤗 Transformers 下载 DPR 模型
>>> from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
>>> import torch
>>> torch.set_grad_enabled(False)
>>> ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
>>> ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
- 加载您的数据集并计算向量表示
>>> from datasets import load_dataset
>>> ds = load_dataset('crime_and_punish', split='train[:100]')
>>> ds_with_embeddings = ds.map(lambda example: {'embeddings': ctx_encoder(**ctx_tokenizer(example["line"], return_tensors="pt"))[0][0].numpy()})
- 使用 Dataset.add_faiss_index() 创建索引
>>> ds_with_embeddings.add_faiss_index(column='embeddings')
- 现在您可以使用
embeddings
索引查询您的数据集。加载 DPR 问题编码器,并使用 Dataset.get_nearest_examples() 搜索问题
>>> from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
>>> q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
>>> q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
>>> question = "Is it serious ?"
>>> question_embedding = q_encoder(**q_tokenizer(question, return_tensors="pt"))[0][0].numpy()
>>> scores, retrieved_examples = ds_with_embeddings.get_nearest_examples('embeddings', question_embedding, k=10)
>>> retrieved_examples["line"][0]
'_that_ serious? It is not serious at all. It’s simply a fantasy to amuse\r\n'
- 您可以使用 Dataset.get_index() 访问索引,并将其用于特殊操作,例如使用
range_search
查询它
>>> faiss_index = ds_with_embeddings.get_index('embeddings').faiss_index
>>> limits, distances, indices = faiss_index.range_search(x=question_embedding.reshape(1, -1), thresh=0.95)
- 查询完成后,使用 Dataset.save_faiss_index() 将索引保存到磁盘
>>> ds_with_embeddings.save_faiss_index('embeddings', 'my_index.faiss')
- 稍后使用 Dataset.load_faiss_index() 重新加载它
>>> ds = load_dataset('crime_and_punish', split='train[:100]')
>>> ds.load_faiss_index('embeddings', 'my_index.faiss')
Elasticsearch
与 FAISS 不同,Elasticsearch 基于精确匹配来检索文档。
在您的机器上启动 Elasticsearch,或者如果您尚未安装,请参阅 Elasticsearch 安装指南。
- 加载您要索引的数据集
>>> from datasets import load_dataset
>>> squad = load_dataset('squad', split='validation')
>>> squad.add_elasticsearch_index("context", host="localhost", port="9200")
- 然后,您可以使用 Dataset.get_nearest_examples() 查询
context
索引
>>> query = "machine"
>>> scores, retrieved_examples = squad.get_nearest_examples("context", query, k=10)
>>> retrieved_examples["title"][0]
'Computational_complexity_theory'
- 如果您想重用索引,请在构建索引时定义
es_index_name
参数
>>> from datasets import load_dataset
>>> squad = load_dataset('squad', split='validation')
>>> squad.add_elasticsearch_index("context", host="localhost", port="9200", es_index_name="hf_squad_val_context")
>>> squad.get_index("context").es_index_name
hf_squad_val_context
- 稍后在调用 Dataset.load_elasticsearch_index() 时使用索引名称重新加载它
>>> from datasets import load_dataset
>>> squad = load_dataset('squad', split='validation')
>>> squad.load_elasticsearch_index("context", host="localhost", port="9200", es_index_name="hf_squad_val_context")
>>> query = "machine"
>>> scores, retrieved_examples = squad.get_nearest_examples("context", query, k=10)
对于更高级的 Elasticsearch 用法,您可以使用自定义设置指定您自己的配置
>>> import elasticsearch as es
>>> import elasticsearch.helpers
>>> from elasticsearch import Elasticsearch
>>> es_client = Elasticsearch([{"host": "localhost", "port": "9200"}]) # default client
>>> es_config = {
... "settings": {
... "number_of_shards": 1,
... "analysis": {"analyzer": {"stop_standard": {"type": "standard", " stopwords": "_english_"}}},
... },
... "mappings": {"properties": {"text": {"type": "text", "analyzer": "standard", "similarity": "BM25"}}},
... } # default config
>>> es_index_name = "hf_squad_context" # name of the index in Elasticsearch
>>> squad.add_elasticsearch_index("context", es_client=es_client, es_config=es_config, es_index_name=es_index_name)