数据集文档

搜索索引

Hugging Face's logo
加入 Hugging Face 社区

并获得增强型文档体验

开始使用

搜索索引

FAISSElasticsearch 可以搜索数据集中示例。当您想要检索与 NLP 任务相关的特定数据集示例时,这将非常有用。例如,如果您正在处理开放域问答任务,您可能只想返回与回答您的问题相关的示例。

本指南将向您展示如何为数据集构建索引,以便您可以搜索它。

FAISS

FAISS 基于文档的向量表示的相似性来检索文档。在本例中,您将使用 DPR 模型生成向量表示。

  1. 从 🤗 Transformers 下载 DPR 模型
>>> from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
>>> import torch
>>> torch.set_grad_enabled(False)
>>> ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
>>> ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
  1. 加载您的数据集并计算向量表示
>>> from datasets import load_dataset
>>> ds = load_dataset('crime_and_punish', split='train[:100]')
>>> ds_with_embeddings = ds.map(lambda example: {'embeddings': ctx_encoder(**ctx_tokenizer(example["line"], return_tensors="pt"))[0][0].numpy()})
  1. 使用 Dataset.add_faiss_index() 创建索引
>>> ds_with_embeddings.add_faiss_index(column='embeddings')
  1. 现在您可以使用 embeddings 索引查询您的数据集。加载 DPR 问题编码器,并使用 Dataset.get_nearest_examples() 搜索问题
>>> from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
>>> q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
>>> q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

>>> question = "Is it serious ?"
>>> question_embedding = q_encoder(**q_tokenizer(question, return_tensors="pt"))[0][0].numpy()
>>> scores, retrieved_examples = ds_with_embeddings.get_nearest_examples('embeddings', question_embedding, k=10)
>>> retrieved_examples["line"][0]
'_that_ serious? It is not serious at all. It’s simply a fantasy to amuse\r\n'
  1. 您可以使用 Dataset.get_index() 访问索引,并将其用于特殊操作,例如使用 range_search 查询它
>>> faiss_index = ds_with_embeddings.get_index('embeddings').faiss_index
>>> limits, distances, indices = faiss_index.range_search(x=question_embedding.reshape(1, -1), thresh=0.95)
  1. 查询完成后,使用 Dataset.save_faiss_index() 将索引保存到磁盘
>>> ds_with_embeddings.save_faiss_index('embeddings', 'my_index.faiss')
  1. 稍后使用 Dataset.load_faiss_index() 重新加载它
>>> ds = load_dataset('crime_and_punish', split='train[:100]')
>>> ds.load_faiss_index('embeddings', 'my_index.faiss')

Elasticsearch

与 FAISS 不同,Elasticsearch 基于精确匹配来检索文档。

在您的机器上启动 Elasticsearch,或者如果您尚未安装,请参阅 Elasticsearch 安装指南

  1. 加载您要索引的数据集
>>> from datasets import load_dataset
>>> squad = load_dataset('squad', split='validation')
  1. 使用 Dataset.add_elasticsearch_index() 构建索引
>>> squad.add_elasticsearch_index("context", host="localhost", port="9200")
  1. 然后,您可以使用 Dataset.get_nearest_examples() 查询 context 索引
>>> query = "machine"
>>> scores, retrieved_examples = squad.get_nearest_examples("context", query, k=10)
>>> retrieved_examples["title"][0]
'Computational_complexity_theory'
  1. 如果您想重用索引,请在构建索引时定义 es_index_name 参数
>>> from datasets import load_dataset
>>> squad = load_dataset('squad', split='validation')
>>> squad.add_elasticsearch_index("context", host="localhost", port="9200", es_index_name="hf_squad_val_context")
>>> squad.get_index("context").es_index_name
hf_squad_val_context
  1. 稍后在调用 Dataset.load_elasticsearch_index() 时使用索引名称重新加载它
>>> from datasets import load_dataset
>>> squad = load_dataset('squad', split='validation')
>>> squad.load_elasticsearch_index("context", host="localhost", port="9200", es_index_name="hf_squad_val_context")
>>> query = "machine"
>>> scores, retrieved_examples = squad.get_nearest_examples("context", query, k=10)

对于更高级的 Elasticsearch 用法,您可以使用自定义设置指定您自己的配置

>>> import elasticsearch as es
>>> import elasticsearch.helpers
>>> from elasticsearch import Elasticsearch
>>> es_client = Elasticsearch([{"host": "localhost", "port": "9200"}])  # default client
>>> es_config = {
...     "settings": {
...         "number_of_shards": 1,
...         "analysis": {"analyzer": {"stop_standard": {"type": "standard", " stopwords": "_english_"}}},
...     },
...     "mappings": {"properties": {"text": {"type": "text", "analyzer": "standard", "similarity": "BM25"}}},
... }  # default config
>>> es_index_name = "hf_squad_context"  # name of the index in Elasticsearch
>>> squad.add_elasticsearch_index("context", es_client=es_client, es_config=es_config, es_index_name=es_index_name)
< > 在 GitHub 上更新