由 SQL 和 Jina Reranker v2 支持的 RAG
作者:Scott Martens @ Jina AI
此笔记本将向您展示如何创建一个简单的检索增强生成 (RAG) 系统,该系统使用 SQL 数据库而不是从文档存储库中获取信息。
工作原理
- 给定一个 SQL 数据库,我们将提取 SQL 表定义(SQL 转储中的
CREATE
行)并将其存储起来。在本教程中,我们已经为您完成了此部分,定义存储在内存中作为列表。从该示例扩展可能需要更复杂的存储。 - 用户以自然语言输入查询。
- Jina Reranker v2(
jinaai/jina-reranker-v2-base-multilingual
),来自 Jina AI 的一个了解 SQL 的重新排序模型,将表定义按其与用户查询的相关性排序。 - 我们将 Mistral 7B Instruct v0.1 (
mistralai/Mistral-7B-Instruct-v0.1
) 与包含用户查询和前三个表定义的提示一起呈现,并要求它编写一个适合任务的 SQL 查询。 - Mistral Instruct 生成一个 SQL 查询,我们将其对数据库运行,检索结果。
- SQL 查询结果被转换为 JSON,并与用户的原始查询、SQL 查询一起呈现在一个新提示中,并要求以自然语言为用户编写一个答案。
- Mistral Instruct 的自然语言文本响应将返回给用户。
数据库
在本教程中,我们使用了一个小型开放访问数据库,其中包含视频游戏销售记录,该数据库存储在 GitHub 上。我们将使用 SQLite 版本,因为 SQLite 非常紧凑、跨平台,并且具有内置的 Python 支持。
软件和硬件要求
我们将运行本地 Jina Reranker v2 模型。如果你使用 Google Colab 来运行此笔记本,请确保你使用的是可以访问 GPU 的运行时。如果你在本地运行它,你需要 Python 3(本教程使用 Python 3.11 安装编写)并且它在启用 CUDA 的 GPU 上运行起来会快得多。
在本教程中,我们还将广泛使用开源的 LlamaIndex RAG 框架,以及 Hugging Face 推理 API 来访问 Mistral 7B Instruct v0.1。你需要一个 Hugging Face 帐户 和一个 访问令牌,至少具有 READ
权限。
如果你使用的是 Google Colab,SQLite 已安装。它可能未安装在你的本地计算机上。如果未安装,请按照 SQLite 网站 上的说明进行安装。Python 接口代码内置于 Python 中,你无需为此安装任何 Python 模块。
设置
安装要求
首先,安装所需的 Python 模块。
!pip install -qU transformers einops llama-index llama-index-postprocessor-jinaai-rerank llama-index-llms-huggingface "huggingface_hub[inference]"
下载数据库
接下来,从 GitHub 下载 SQLite 数据库 videogames.db
到本地文件空间。如果你的系统上没有 wget
,请从 此链接 下载数据库,并将其放在运行此笔记本的同一目录中。
!wget https://github.com/bbrumm/databasestar/raw/main/sample_databases/sample_db_videogames/sqlite/videogames.db
下载并运行 Jina Reranker v2
以下代码将下载模型 jina-reranker-v2-base-multilingual
并本地运行。
from transformers import AutoModelForSequenceClassification
reranker_model = AutoModelForSequenceClassification.from_pretrained(
"jinaai/jina-reranker-v2-base-multilingual",
torch_dtype="auto",
trust_remote_code=True,
)
reranker_model.to("cuda") # or 'cpu' if no GPU is available
reranker_model.eval()
设置与 Mistral Instruct 的接口
我们将使用 LlamaIndex 为与 Hugging Face 推理 API 的连接以及运行在那里的 mistralai/Mixtral-8x7B-Instruct-v0.1
副本创建一个占位符对象。
首先,从你的 Hugging Face 帐户设置页面 获取一个 Hugging Face 访问令牌。
在下面提示时输入它。
import getpass
print("Paste your Hugging Face access token here: ")
hf_token = getpass.getpass()
接下来,初始化 LlamaIndex 中的 HuggingFaceInferenceAPI
类的实例,并将其存储为 mistral_llm
。
from llama_index.llms.huggingface import HuggingFaceInferenceAPI
mistral_llm = HuggingFaceInferenceAPI(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
使用 SQL 感知的 Jina Reranker v2
我们从 GitHub 上的数据库导入文件 中提取了八个表定义。运行下面的命令将其放入名为 table_declarations
的 Python 列表中。
table_declarations = [
"CREATE TABLE platform (\n\tid INTEGER PRIMARY KEY,\n\tplatform_name TEXT DEFAULT NULL\n);",
"CREATE TABLE genre (\n\tid INTEGER PRIMARY KEY,\n\tgenre_name TEXT DEFAULT NULL\n);",
"CREATE TABLE publisher (\n\tid INTEGER PRIMARY KEY,\n\tpublisher_name TEXT DEFAULT NULL\n);",
"CREATE TABLE region (\n\tid INTEGER PRIMARY KEY,\n\tregion_name TEXT DEFAULT NULL\n);",
"CREATE TABLE game (\n\tid INTEGER PRIMARY KEY,\n\tgenre_id INTEGER,\n\tgame_name TEXT DEFAULT NULL,\n\tCONSTRAINT fk_gm_gen FOREIGN KEY (genre_id) REFERENCES genre(id)\n);",
"CREATE TABLE game_publisher (\n\tid INTEGER PRIMARY KEY,\n\tgame_id INTEGER DEFAULT NULL,\n\tpublisher_id INTEGER DEFAULT NULL,\n\tCONSTRAINT fk_gpu_gam FOREIGN KEY (game_id) REFERENCES game(id),\n\tCONSTRAINT fk_gpu_pub FOREIGN KEY (publisher_id) REFERENCES publisher(id)\n);",
"CREATE TABLE game_platform (\n\tid INTEGER PRIMARY KEY,\n\tgame_publisher_id INTEGER DEFAULT NULL,\n\tplatform_id INTEGER DEFAULT NULL,\n\trelease_year INTEGER DEFAULT NULL,\n\tCONSTRAINT fk_gpl_gp FOREIGN KEY (game_publisher_id) REFERENCES game_publisher(id),\n\tCONSTRAINT fk_gpl_pla FOREIGN KEY (platform_id) REFERENCES platform(id)\n);",
"CREATE TABLE region_sales (\n\tregion_id INTEGER DEFAULT NULL,\n\tgame_platform_id INTEGER DEFAULT NULL,\n\tnum_sales REAL,\n CONSTRAINT fk_rs_gp FOREIGN KEY (game_platform_id) REFERENCES game_platform(id),\n\tCONSTRAINT fk_rs_reg FOREIGN KEY (region_id) REFERENCES region(id)\n);",
]
现在,我们定义一个函数,该函数接受自然语言查询和表定义列表,使用 Jina Reranker v2 对所有表定义进行评分,并按从最高评分到最低评分的顺序返回它们。
from typing import List, Tuple
def rank_tables(query: str, table_specs: List[str], top_n: int = 0) -> List[Tuple[float, str]]:
"""
Get sorted pairs of scores and table specifications, then return the top N,
or all if top_n is 0 or default.
"""
pairs = [[query, table_spec] for table_spec in table_specs]
scores = reranker_model.compute_score(pairs)
scored_tables = [(score, table_spec) for score, table_spec in zip(scores, table_specs)]
scored_tables.sort(key=lambda x: x[0], reverse=True)
if top_n and top_n < len(scored_tables):
return scored_tables[0:top_n]
return scored_tables
Jina Reranker v2 会对我们提供的每个表定义进行评分,默认情况下,此函数将返回所有表定义及其评分。可选参数 top_n
将返回结果的数量限制为用户定义的数量,从最高评分的结果开始。
试一试。首先,定义一个查询。
user_query = "Identify the top 10 platforms by total sales."
运行 rank_tables
以获得一个表定义列表。让我们将 top_n
设置为 3 以限制返回列表的大小并将其分配给变量 ranked_tables
,然后检查结果。
ranked_tables = rank_tables(user_query, table_declarations, top_n=3)
ranked_tables
输出应该包括 region_sales
、platform
和 game_platform
表,这些表似乎都是查找查询答案的合理位置。
使用 Mistral Instruct 生成 SQL
我们将让 Mistral Instruct v0.1 根据 Reranker 的前三张表的声明编写一个满足用户查询的 SQL 查询。
首先,我们使用 LlamaIndex 的 PromptTemplate
类为此目的创建一个提示。
from llama_index.core import PromptTemplate
make_sql_prompt_tmpl_text = """
Generate a SQL query to answer the following question from the user:
\"{query_str}\"
The SQL query should use only tables with the following SQL definitions:
Table 1:
{table_1}
Table 2:
{table_2}
Table 3:
{table_3}
Make sure you ONLY output an SQL query and no explanation.
"""
make_sql_prompt_tmpl = PromptTemplate(make_sql_prompt_tmpl_text)
我们使用 format
方法用用户查询和 Jina Reranker v2 的前三张表声明来填充模板字段。
make_sql_prompt = make_sql_prompt_tmpl.format(
query_str=user_query, table_1=ranked_tables[0][1], table_2=ranked_tables[1][1], table_3=ranked_tables[2][1]
)
你可以看到我们将传递给 Mistral Instruct 的实际文本。
print(make_sql_prompt)
现在,让我们将提示发送到 Mistral Instruct 并检索其响应。
response = mistral_llm.complete(make_sql_prompt)
sql_query = str(response)
print(sql_query)
运行 SQL 查询
使用内置的 Python 接口到 SQLite 来运行上面的查询,针对数据库 videogames.db
。
import sqlite3
con = sqlite3.connect("videogames.db")
cur = con.cursor()
sql_response = cur.execute(sql_query).fetchall()
有关 SQLite 接口的详细信息,请参阅 Python3 文档。
检查结果。
sql_response
你可以通过运行自己的 SQL 查询来检查结果是否正确。存储在此数据库中的销售数据以浮点数形式表示,可能是数千或数百万的销量。
获取自然语言答案
现在,我们将使用新的提示模板将用户的查询、SQL 查询和结果传递回 Mistral Instruct。
首先,使用 LlamaIndex 创建新的提示模板,与上面一样
rag_prompt_tmpl_str = """
Use the information in the JSON table to answer the following user query.
Do not explain anything, just answer concisely. Use natural language in your
answer, not computer formatting.
USER QUERY: {query_str}
JSON table:
{json_table}
This table was generated by the following SQL query:
{sql_query}
Answer ONLY using the information in the table and the SQL query, and if the
table does not provide the information to answer the question, answer
"No Information".
"""
rag_prompt_tmpl = PromptTemplate(rag_prompt_tmpl_str)
我们将把 SQL 输出转换为 JSON,这是一种 Mistral Instruct v0.1 理解的格式。
填充模板字段
import json
rag_prompt = rag_prompt_tmpl.format(
query_str="Identify the top 10 platforms by total sales", json_table=json.dumps(sql_response), sql_query=sql_query
)
现在从 Mistral Instruct 获取自然语言响应
rag_response = mistral_llm.complete(rag_prompt)
print(str(rag_response))
自己尝试
让我们将所有这些整理成一个带有异常捕获的功能
def answer_sql(user_query: str) -> str:
try:
ranked_tables = rank_tables(user_query, table_declarations, top_n=3)
except Exception as e:
print(f"Ranking failed.\nUser query:\n{user_query}\n\n")
raise (e)
make_sql_prompt = make_sql_prompt_tmpl.format(
query_str=user_query, table_1=ranked_tables[0][1], table_2=ranked_tables[1][1], table_3=ranked_tables[2][1]
)
try:
response = mistral_llm.complete(make_sql_prompt)
except Exception as e:
print(f"SQL query generation failed\nPrompt:\n{make_sql_prompt}\n\n")
raise (e)
# Backslash removal is a necessary hack because sometimes Mistral puts them
# in its generated code.
sql_query = str(response).replace("\\", "")
try:
sql_response = sqlite3.connect("videogames.db").cursor().execute(sql_query).fetchall()
except Exception as e:
print(f"SQL querying failed. Query:\n{sql_query}\n\n")
raise (e)
rag_prompt = rag_prompt_tmpl.format(query_str=user_query, json_table=json.dumps(sql_response), sql_query=sql_query)
try:
rag_response = mistral_llm.complete(rag_prompt)
return str(rag_response)
except Exception as e:
print(f"Answer generation failed. Prompt:\n{rag_prompt}\n\n")
raise (e)
试试看
print(answer_sql("Identify the top 10 platforms by total sales."))
尝试其他一些查询
print(answer_sql("Summarize sales by region."))
print(answer_sql("List the publisher with the largest number of published games."))
print(answer_sql("Display the year with most games released."))
print(answer_sql("What is the most popular game genre on the Wii platform?"))
print(answer_sql("What is the most popular game genre of 2012?"))
尝试您自己的查询
print(answer_sql("<INSERT QUESTION OR INSTRUCTION HERE>"))
回顾和结论
我们已经向您展示了如何为自然语言问答创建一个非常基本的 RAG(检索增强生成)系统,该系统使用 SQL 数据库作为信息源。在此实现中,我们使用相同的语言模型(Mistral Instruct v0.1)来生成 SQL 查询并构建自然语言响应。
这里的数据库是一个非常小的示例,将其扩展可能需要比仅仅对表格定义列表进行排序更复杂的方法。您可能希望使用两阶段流程,其中嵌入模型和向量存储最初检索更多结果,但重新排序模型会将其缩减到您可以放入生成语言模型提示中的任何数量。
此笔记本假设没有任何请求需要超过三个表格来满足,显然,在实践中,这并不总是正确的。Mistral 7B Instruct v0.1 并不保证生成正确的(甚至可执行的)SQL 输出。在生产环境中,类似于这种操作需要更深入的错误处理。
更复杂的错误处理、更长的输入上下文窗口以及专门用于 SQL 特定任务的生成模型可能会在实际应用中产生很大影响。
尽管如此,您可以在此处看到 RAG 概念如何扩展到结构化数据库,从而极大地扩展了其使用范围。
< > 在 GitHub 上更新