语义搜索RAG系统优化
社区文章 发布于2024年7月13日
文章正在撰写中... ✍️
本文参考了以下文章: https://huggingface.co/blog/not-lain/rag-chatbot-using-llama3
导言
检索增强生成(RAG)是一种强大的技术,它结合了信息检索和文本生成,以产生更准确、更具上下文相关性的答案。在本文中,我们将探讨如何优化传统RAG系统以进行语义搜索,从而提高结果的质量和相关性。
为什么要为语义搜索进行优化?
语义搜索超越了简单的关键词匹配。它旨在理解查询的意图和上下文,即使文本中没有确切的术语,也能找到相关信息。这种方法对RAG系统特别有益,因为它可以检索更相关的信息来为生成模型提供支持。
语义优化关键修改
我们将研究对传统RAG系统进行修改以对其进行语义优化。我们将比较每个关键组件的初始版本和优化版本。
数据结构
在深入优化之前,这是我们的数据块结构
{
"id": "01",
"title": "…",
"content": "…",
"tags": ["…","…"]
}
这些块以.parquet格式存储,并发布到Hugging Face以便于使用。
此比较将分三个阶段进行。第一阶段:向数据集中添加嵌入;第二阶段:验证索引和检索;第三阶段:将RAG系统集成到Gradio中。
1. 向数据集中添加嵌入。
基本版本
from datasets import load_dataset
dataset = load_dataset("not-lain/wikipedia")
dataset
from sentence_transformers import SentenceTransformer
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
# embed the dataset
def embed(batch):
# or you can combine multiple columns here, for example the title and the text
information = batch["text"]
return {"embeddings" : ST.encode(information)}
dataset = dataset.map(embed,batched=True,batch_size=16)
dataset.push_to_hub("not-lain/wikipedia", revision="embedded")
优化语义搜索版本
from datasets import load_dataset
dataset = load_dataset("path-dataset")
dataset
from sentence_transformers import SentenceTransformer
# Nous gardons le modèle original pour sa qualité d'embedding
ST = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
def embed(batch):
"""
Ajoute une colonne 'embeddings' au dataset en tenant compte de la structure spécifique des chunks
"""
# Combinaison du titre, du contenu et des tags pour une représentation riche
combined_info = []
for item in batch:
# Joindre les tags en une seule chaîne
tags_string = " ".join(item['tags'])
# Combiner titre, contenu et tags
combined = f"{item['title']} {item['content']} {tags_string}"
combined_info.append(combined)
# Création et normalisation des embeddings
embeddings = ST.encode(combined_info, normalize_embeddings=True)
return {"embeddings": embeddings}
# Utilisation de map avec batching pour une meilleure efficacité
dataset = dataset.map(embed, batched=True, batch_size=16)
dataset.push_to_hub("path-dataset", revision="embedded")
所做的改进:
- 嵌入模型已更改为更适合语义搜索的模型:“sentence-transformers/all-MiniLM-L6-v2”。
- 结合标题和内容以获得更丰富的表示。
- 归一化嵌入以提高向量比较的一致性。
2. 在数据集中进行语义搜索。
基本版本
from datasets import load_dataset
dataset = load_dataset("not-lain/wikipedia",revision = "embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings") # column name that has the embeddings of the dataset
def search(query: str, k: int = 3 ):
"""a function that embeds a new query and returns the most probable results"""
embedded_query = ST.encode(query) # embed new query
scores, retrieved_examples = data.get_nearest_examples( # retrieve results
"embeddings", embedded_query, # compare our new embedded query with the dataset embeddings
k=k # get only top k results
)
return scores, retrieved_examples
scores , result = search("anarchy", 4 ) # search for word anarchy and get the best 4 matching values from the dataset
# the lower the better
scores
result['title']
print(result["text"][0])
优化版本
from datasets import load_dataset
dataset = load_dataset("path-dataset",revision = "embedded")
data = dataset["train"]
# Optimisation : utilisation de la métrique du produit scalaire pour les embeddings normalisés
dataset.add_faiss_index("embeddings", metric_type=faiss.METRIC_INNER_PRODUCT) #Vérifier que la fonction existe et qu'elle fonctionne.
def semantic_search(query: str, k: int = 3):
# Normalisation de l'embedding de la requête
embedded_query = ST.encode(query, normalize_embeddings=True)
scores, retrieved_chunks = dataset.get_nearest_examples(
"embeddings", embedded_query, k=k
)
results = []
for score, chunk in zip(scores, retrieved_chunks):
results.append({
'score': score,
'id': chunk['id'],
'title': chunk['title'],
'content': chunk['content'],
'tags': chunk['tags'],
'similarity': (1 + score) / 2 # Conversion de la similarité cosinus [-1, 1] à [0, 1]
})
results.sort(key=lambda x: x['similarity'], reverse=True)
return results
query = "Quelle est l'identité de Lucas?"
results = semantic_search(query, k=3)
for result in results:
print(f"ID: {result['id']}")
print(f"Titre: {result['title']}")
print(f"Similarité: {result['similarity']:.2f}")
print(f"Tags: {', '.join(result['tags'])}")
print(f"Contenu: {result['content'][:200]}...")
print("---")
3. 集成到Gradio中。
基本版本