开源 AI 食谱文档
RAG 由 SQL 和 Jina Reranker v2 支持
并获得增强的文档体验
开始使用
RAG 由 SQL 和 Jina Reranker v2 支持
作者:Scott Martens @ Jina AI
本笔记本将向您展示如何制作一个简单的检索增强生成 (RAG) 系统,该系统利用 SQL 数据库而不是从文档存储中提取信息。
工作原理
- 给定一个 SQL 数据库,我们提取 SQL 表定义(SQL 转储中的
CREATE
行)并将它们存储起来。在本教程中,我们为您完成了这部分工作,定义以列表形式存储在内存中。从本示例进行扩展可能需要更复杂的存储。 - 用户输入自然语言查询。
- Jina Reranker v2 (
jinaai/jina-reranker-v2-base-multilingual
),来自 Jina AI 的 SQL 感知重排序模型,按表定义与用户查询的相关性对其进行排序。 - 我们向 Mistral 7B Instruct v0.1 (
mistralai/Mistral-7B-Instruct-v0.1
) 提供一个提示,其中包含用户的查询和前三个表定义,并请求编写一个 SQL 查询来适应任务。 - Mistral Instruct 生成一个 SQL 查询,我们针对数据库运行它,检索结果。
- SQL 查询结果转换为 JSON 并呈现在 Mistral Instruct 的新提示中,以及用户的原始查询、SQL 查询,并请求用自然语言为用户撰写答案。
- Mistral Instruct 的自然语言文本响应将返回给用户。
数据库
在本教程中,我们使用了一个小型开放访问的视频游戏销售记录数据库,存储在 GitHub 上。我们将使用 SQLite 版本,因为 SQLite 非常紧凑、跨平台并且具有内置的 Python 支持。
软件和硬件要求
我们将本地运行 Jina Reranker v2 模型。如果您使用 Google Colab 运行此笔记本,请确保您使用的运行时可以访问 GPU。如果您在本地运行它,您将需要 Python 3(本教程是使用 Python 3.11 安装编写的),并且使用 CUDA 启用的 GPU 运行速度会快得多。
我们还将在本教程中广泛使用开源 LlamaIndex RAG 框架,以及 Hugging Face 推理 API 来访问 Mistral 7B Instruct v0.1。您将需要一个 Hugging Face 帐户和一个至少具有 READ
访问权限的 访问令牌。
如果您使用 Google Colab,则已安装 SQLite。它可能未安装在您的本地计算机上。如果未安装,请按照 SQLite 网站上的说明进行安装。Python 接口代码内置于 Python 中,您无需为其安装任何 Python 模块。
设置
安装要求
首先,安装所需的 Python 模块
!pip install -qU transformers einops llama-index llama-index-postprocessor-jinaai-rerank llama-index-llms-huggingface "huggingface_hub[inference]"
下载数据库
接下来,从 GitHub 将 SQLite 数据库 videogames.db
下载到本地文件空间。如果您的系统上没有 wget
,请从此链接下载数据库,并将其放在您运行此笔记本的同一目录中
!wget https://github.com/bbrumm/databasestar/raw/main/sample_databases/sample_db_videogames/sqlite/videogames.db
下载并运行 Jina Reranker v2
以下代码将下载模型 jina-reranker-v2-base-multilingual
并在本地运行它
from transformers import AutoModelForSequenceClassification
reranker_model = AutoModelForSequenceClassification.from_pretrained(
"jinaai/jina-reranker-v2-base-multilingual",
torch_dtype="auto",
trust_remote_code=True,
)
reranker_model.to("cuda") # or 'cpu' if no GPU is available
reranker_model.eval()
设置与 Mistral Instruct 的接口
我们将使用 LlamaIndex 创建一个持有者对象,用于连接到 Hugging Face 推理 API 和在那里运行的 mistralai/Mixtral-8x7B-Instruct-v0.1
副本。
首先,从您的 Hugging Face 帐户设置页面获取 Hugging Face 访问令牌。
在下面提示时输入它
import getpass
print("Paste your Hugging Face access token here: ")
hf_token = getpass.getpass()
接下来,从 LlamaIndex 初始化 HuggingFaceInferenceAPI
类的实例,并将其存储为 mistral_llm
from llama_index.llms.huggingface import HuggingFaceInferenceAPI
mistral_llm = HuggingFaceInferenceAPI(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
使用 SQL 感知 Jina Reranker v2
我们从 GitHub 上的数据库导入文件中提取了八个表定义。运行以下命令将它们放入名为 table_declarations
的 Python 列表中
table_declarations = [
"CREATE TABLE platform (\n\tid INTEGER PRIMARY KEY,\n\tplatform_name TEXT DEFAULT NULL\n);",
"CREATE TABLE genre (\n\tid INTEGER PRIMARY KEY,\n\tgenre_name TEXT DEFAULT NULL\n);",
"CREATE TABLE publisher (\n\tid INTEGER PRIMARY KEY,\n\tpublisher_name TEXT DEFAULT NULL\n);",
"CREATE TABLE region (\n\tid INTEGER PRIMARY KEY,\n\tregion_name TEXT DEFAULT NULL\n);",
"CREATE TABLE game (\n\tid INTEGER PRIMARY KEY,\n\tgenre_id INTEGER,\n\tgame_name TEXT DEFAULT NULL,\n\tCONSTRAINT fk_gm_gen FOREIGN KEY (genre_id) REFERENCES genre(id)\n);",
"CREATE TABLE game_publisher (\n\tid INTEGER PRIMARY KEY,\n\tgame_id INTEGER DEFAULT NULL,\n\tpublisher_id INTEGER DEFAULT NULL,\n\tCONSTRAINT fk_gpu_gam FOREIGN KEY (game_id) REFERENCES game(id),\n\tCONSTRAINT fk_gpu_pub FOREIGN KEY (publisher_id) REFERENCES publisher(id)\n);",
"CREATE TABLE game_platform (\n\tid INTEGER PRIMARY KEY,\n\tgame_publisher_id INTEGER DEFAULT NULL,\n\tplatform_id INTEGER DEFAULT NULL,\n\trelease_year INTEGER DEFAULT NULL,\n\tCONSTRAINT fk_gpl_gp FOREIGN KEY (game_publisher_id) REFERENCES game_publisher(id),\n\tCONSTRAINT fk_gpl_pla FOREIGN KEY (platform_id) REFERENCES platform(id)\n);",
"CREATE TABLE region_sales (\n\tregion_id INTEGER DEFAULT NULL,\n\tgame_platform_id INTEGER DEFAULT NULL,\n\tnum_sales REAL,\n CONSTRAINT fk_rs_gp FOREIGN KEY (game_platform_id) REFERENCES game_platform(id),\n\tCONSTRAINT fk_rs_reg FOREIGN KEY (region_id) REFERENCES region(id)\n);",
]
现在,我们定义一个函数,该函数接受自然语言查询和表定义列表,使用 Jina Reranker v2 对所有表定义进行评分,并按从最高分到最低分的顺序返回它们
from typing import List, Tuple
def rank_tables(query: str, table_specs: List[str], top_n: int = 0) -> List[Tuple[float, str]]:
"""
Get sorted pairs of scores and table specifications, then return the top N,
or all if top_n is 0 or default.
"""
pairs = [[query, table_spec] for table_spec in table_specs]
scores = reranker_model.compute_score(pairs)
scored_tables = [(score, table_spec) for score, table_spec in zip(scores, table_specs)]
scored_tables.sort(key=lambda x: x[0], reverse=True)
if top_n and top_n < len(scored_tables):
return scored_tables[0:top_n]
return scored_tables
Jina Reranker v2 对我们提供的每个表定义进行评分,默认情况下,此函数将返回所有表定义及其分数。可选参数 top_n
将返回的结果数限制为用户定义的数字,从得分最高的开始。
试用一下。首先,定义一个查询
user_query = "Identify the top 10 platforms by total sales."
运行 rank_tables
以获取表定义列表。让我们将 top_n
设置为 3 以限制返回列表的大小,并将其分配给变量 ranked_tables
,然后检查结果
ranked_tables = rank_tables(user_query, table_declarations, top_n=3)
ranked_tables
输出应包括表 region_sales
、platform
和 game_platform
,这些表似乎都是查找查询答案的合理位置。
使用 Mistral Instruct 生成 SQL
我们将让 Mistral Instruct v0.1 根据重排序器排名前三的表的声明,编写一个满足用户查询的 SQL 查询。
首先,我们使用 LlamaIndex 的 PromptTemplate
类为此目的创建一个提示
from llama_index.core import PromptTemplate
make_sql_prompt_tmpl_text = """
Generate a SQL query to answer the following question from the user:
\"{query_str}\"
The SQL query should use only tables with the following SQL definitions:
Table 1:
{table_1}
Table 2:
{table_2}
Table 3:
{table_3}
Make sure you ONLY output an SQL query and no explanation.
"""
make_sql_prompt_tmpl = PromptTemplate(make_sql_prompt_tmpl_text)
我们使用 format
方法使用 Jina Reranker v2 中的用户查询和前三个表声明填充模板字段
make_sql_prompt = make_sql_prompt_tmpl.format(
query_str=user_query, table_1=ranked_tables[0][1], table_2=ranked_tables[1][1], table_3=ranked_tables[2][1]
)
您可以看到我们将要传递给 Mistral Instruct 的实际文本
print(make_sql_prompt)
现在让我们将提示发送到 Mistral Instruct 并检索其响应
response = mistral_llm.complete(make_sql_prompt)
sql_query = str(response)
print(sql_query)
运行 SQL 查询
使用内置的 Python 接口到 SQLite,针对数据库 videogames.db
运行上述查询
import sqlite3
con = sqlite3.connect("videogames.db")
cur = con.cursor()
sql_response = cur.execute(sql_query).fetchall()
有关 SQLite 接口的详细信息,请参阅 Python3 文档。
检查结果
sql_response
您可以通过运行自己的 SQL 查询来检查这是否正确。存储在此数据库中的销售数据采用浮点数形式,大概是数千或数百万的单位销量。
获取自然语言答案
现在,我们将用户的查询、SQL 查询和结果与新的提示模板一起传递回 Mistral Instruct。
首先,使用 LlamaIndex 创建新的提示模板,与上面相同
rag_prompt_tmpl_str = """
Use the information in the JSON table to answer the following user query.
Do not explain anything, just answer concisely. Use natural language in your
answer, not computer formatting.
USER QUERY: {query_str}
JSON table:
{json_table}
This table was generated by the following SQL query:
{sql_query}
Answer ONLY using the information in the table and the SQL query, and if the
table does not provide the information to answer the question, answer
"No Information".
"""
rag_prompt_tmpl = PromptTemplate(rag_prompt_tmpl_str)
我们将 SQL 输出转换为 Mistral Instruct v0.1 理解的 JSON 格式。
填充模板字段
import json
rag_prompt = rag_prompt_tmpl.format(
query_str="Identify the top 10 platforms by total sales", json_table=json.dumps(sql_response), sql_query=sql_query
)
现在向 Mistral Instruct 请求自然语言响应
rag_response = mistral_llm.complete(rag_prompt)
print(str(rag_response))
自己尝试
让我们将所有这些组织到一个具有异常捕获的函数中
def answer_sql(user_query: str) -> str:
try:
ranked_tables = rank_tables(user_query, table_declarations, top_n=3)
except Exception as e:
print(f"Ranking failed.\nUser query:\n{user_query}\n\n")
raise (e)
make_sql_prompt = make_sql_prompt_tmpl.format(
query_str=user_query, table_1=ranked_tables[0][1], table_2=ranked_tables[1][1], table_3=ranked_tables[2][1]
)
try:
response = mistral_llm.complete(make_sql_prompt)
except Exception as e:
print(f"SQL query generation failed\nPrompt:\n{make_sql_prompt}\n\n")
raise (e)
# Backslash removal is a necessary hack because sometimes Mistral puts them
# in its generated code.
sql_query = str(response).replace("\\", "")
try:
sql_response = sqlite3.connect("videogames.db").cursor().execute(sql_query).fetchall()
except Exception as e:
print(f"SQL querying failed. Query:\n{sql_query}\n\n")
raise (e)
rag_prompt = rag_prompt_tmpl.format(query_str=user_query, json_table=json.dumps(sql_response), sql_query=sql_query)
try:
rag_response = mistral_llm.complete(rag_prompt)
return str(rag_response)
except Exception as e:
print(f"Answer generation failed. Prompt:\n{rag_prompt}\n\n")
raise (e)
试用一下
print(answer_sql("Identify the top 10 platforms by total sales."))
尝试其他一些查询
print(answer_sql("Summarize sales by region."))
print(answer_sql("List the publisher with the largest number of published games."))
print(answer_sql("Display the year with most games released."))
print(answer_sql("What is the most popular game genre on the Wii platform?"))
print(answer_sql("What is the most popular game genre of 2012?"))
尝试您自己的查询
print(answer_sql("<INSERT QUESTION OR INSTRUCTION HERE>"))
回顾与结论
我们向您展示了如何构建一个非常基础的 RAG(检索增强生成)系统,用于自然语言问答,该系统使用 SQL 数据库作为信息来源。在这个实现中,我们使用同一个大型语言模型 (Mistral Instruct v0.1) 来生成 SQL 查询并构建自然语言响应。
这里的数据库是一个非常小的示例,扩展它可能需要比仅对表定义列表进行排序更复杂的方法。您可能需要使用两阶段过程,其中嵌入模型和向量存储最初检索更多结果,但重新排序模型会将结果修剪为您能够放入生成语言模型提示中的任何数量。
本笔记本假设没有请求需要超过三个表来满足,显然,在实践中,情况并非总是如此。Mistral 7B Instruct v0.1 不能保证生成正确(甚至可执行)的 SQL 输出。在生产环境中,类似这样的情况需要更深入的错误处理。
更复杂的错误处理、更长的输入上下文窗口以及专门用于 SQL 特定任务的生成模型可能会在实际应用中产生重大影响。
尽管如此,您可以在这里看到 RAG 概念如何扩展到结构化数据库,从而极大地扩展了其使用范围。
< > 在 GitHub 上更新