语义搜索RAG系统优化

社区文章 发布于2024年7月13日

文章正在撰写中... ✍️

本文参考了以下文章: https://huggingface.co/blog/not-lain/rag-chatbot-using-llama3

导言

检索增强生成(RAG)是一种强大的技术,它结合了信息检索和文本生成,以产生更准确、更具上下文相关性的答案。在本文中,我们将探讨如何优化传统RAG系统以进行语义搜索,从而提高结果的质量和相关性。

为什么要为语义搜索进行优化?

语义搜索超越了简单的关键词匹配。它旨在理解查询的意图和上下文,即使文本中没有确切的术语,也能找到相关信息。这种方法对RAG系统特别有益,因为它可以检索更相关的信息来为生成模型提供支持。

语义优化关键修改

我们将研究对传统RAG系统进行修改以对其进行语义优化。我们将比较每个关键组件的初始版本和优化版本。

数据结构

在深入优化之前,这是我们的数据块结构

{
    "id": "01",
    "title": "…",
    "content": "…",
    "tags": ["…","…"]
}

这些块以.parquet格式存储,并发布到Hugging Face以便于使用。

此比较将分三个阶段进行。第一阶段:向数据集中添加嵌入;第二阶段:验证索引和检索;第三阶段:将RAG系统集成到Gradio中。

1. 向数据集中添加嵌入。

基本版本

from datasets import load_dataset

dataset = load_dataset("not-lain/wikipedia")
dataset
from sentence_transformers import SentenceTransformer
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
# embed the dataset
def embed(batch):
  # or you can combine multiple columns here, for example the title and the text
  information = batch["text"]
  return {"embeddings" : ST.encode(information)}
dataset = dataset.map(embed,batched=True,batch_size=16)
dataset.push_to_hub("not-lain/wikipedia", revision="embedded")

优化语义搜索版本

from datasets import load_dataset

dataset = load_dataset("path-dataset")
dataset
from sentence_transformers import SentenceTransformer

# Nous gardons le modèle original pour sa qualité d'embedding
ST = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
def embed(batch):
    """
    Ajoute une colonne 'embeddings' au dataset en tenant compte de la structure spécifique des chunks
    """
    # Combinaison du titre, du contenu et des tags pour une représentation riche
    combined_info = []
    for item in batch:
        # Joindre les tags en une seule chaîne
        tags_string = " ".join(item['tags'])
        # Combiner titre, contenu et tags
        combined = f"{item['title']} {item['content']} {tags_string}"
        combined_info.append(combined)
    
    # Création et normalisation des embeddings
    embeddings = ST.encode(combined_info, normalize_embeddings=True)
    
    return {"embeddings": embeddings}

# Utilisation de map avec batching pour une meilleure efficacité
dataset = dataset.map(embed, batched=True, batch_size=16)
dataset.push_to_hub("path-dataset", revision="embedded")

所做的改进:

  • 嵌入模型已更改为更适合语义搜索的模型:“sentence-transformers/all-MiniLM-L6-v2”。
  • 结合标题和内容以获得更丰富的表示。
  • 归一化嵌入以提高向量比较的一致性。

2. 在数据集中进行语义搜索。

基本版本

from datasets import load_dataset

dataset = load_dataset("not-lain/wikipedia",revision = "embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings") # column name that has the embeddings of the dataset
def search(query: str, k: int = 3 ):
    """a function that embeds a new query and returns the most probable results"""
    embedded_query = ST.encode(query) # embed new query
    scores, retrieved_examples = data.get_nearest_examples( # retrieve results
        "embeddings", embedded_query, # compare our new embedded query with the dataset embeddings
        k=k # get only top k results
    )
    return scores, retrieved_examples
scores , result = search("anarchy", 4 ) # search for word anarchy and get the best 4 matching values from the dataset
# the lower the better
scores
result['title']
print(result["text"][0])

优化版本

from datasets import load_dataset

dataset = load_dataset("path-dataset",revision = "embedded")
data = dataset["train"]
# Optimisation : utilisation de la métrique du produit scalaire pour les embeddings normalisés
dataset.add_faiss_index("embeddings", metric_type=faiss.METRIC_INNER_PRODUCT) #Vérifier que la fonction existe et qu'elle fonctionne.
def semantic_search(query: str, k: int = 3):
    # Normalisation de l'embedding de la requête
    embedded_query = ST.encode(query, normalize_embeddings=True)
    
    scores, retrieved_chunks = dataset.get_nearest_examples(
        "embeddings", embedded_query, k=k
    )
    
    results = []
    for score, chunk in zip(scores, retrieved_chunks):
        results.append({
            'score': score,
            'id': chunk['id'],
            'title': chunk['title'],
            'content': chunk['content'],
            'tags': chunk['tags'],
            'similarity': (1 + score) / 2  # Conversion de la similarité cosinus [-1, 1] à [0, 1]
        })
    
    results.sort(key=lambda x: x['similarity'], reverse=True)
    return results
query = "Quelle est l'identité de Lucas?"
results = semantic_search(query, k=3)

for result in results:
    print(f"ID: {result['id']}")
    print(f"Titre: {result['title']}")
    print(f"Similarité: {result['similarity']:.2f}")
    print(f"Tags: {', '.join(result['tags'])}")
    print(f"Contenu: {result['content'][:200]}...")
    print("---")

3. 集成到Gradio中。

基本版本


社区

注册登录 发表评论