开源 AI 食谱文档

由 SQL 和 Jina Reranker v2 支持的 RAG

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

Open In Colab

由 SQL 和 Jina Reranker v2 支持的 RAG

作者:Scott Martens @ Jina AI

本笔记本将向您展示如何制作一个简单的检索增强生成 (RAG) 系统,该系统从 SQL 数据库而不是从文档库中提取信息。

工作原理

  • 给定一个 SQL 数据库,我们提取 SQL 表定义 (SQL 转储中的 `CREATE` 行) 并存储它们。在本教程中,我们已经为您完成了这一部分,定义以列表形式存储在内存中。从本示例扩展可能需要更复杂的存储方式。
  • 用户以自然语言输入一个查询。
  • Jina Reranker v2 (`jinaai/jina-reranker-v2-base-multilingual`),一个来自 Jina AI 的支持 SQL 的重排模型,会根据与用户查询的相关性对表定义进行排序。
  • 我们向 Mistral 7B Instruct v0.1 (`mistralai/Mistral-7B-Instruct-v0.1`) 提供一个提示,其中包含用户的查询和排名前三的表定义,并请求它编写一个 SQL 查询来完成任务。
  • Mistral Instruct 生成一个 SQL 查询,我们对数据库执行该查询,检索结果。
  • SQL 查询结果被转换为 JSON 格式,并与用户的原始查询、SQL 查询一起在一个新的提示中提供给 Mistral Instruct,并请求它用自然语言为用户撰写一个答案。
  • Mistral Instruct 的自然语言文本响应将返回给用户。

数据库

在本教程中,我们使用一个关于视频游戏销售记录的小型开放访问数据库,该数据库存储在 GitHub 上。我们将使用 SQLite 版本,因为 SQLite 非常紧凑、跨平台,并且内置了 Python 支持。

软硬件要求

我们将在本地运行 Jina Reranker v2 模型。如果您使用 Google Colab 运行此笔记本,请确保您使用的运行时可以访问 GPU。如果您在本地运行,您将需要 Python 3 (本教程使用 Python 3.11 版本编写),并且在有支持 CUDA 的 GPU 的情况下运行速度会*快得多*。

在本教程中,我们还将广泛使用开源的 LlamaIndex RAG 框架,以及 Hugging Face 推理 API 来访问 Mistral 7B Instruct v0.1。您将需要一个 Hugging Face 账户 和一个具有至少 `READ` 访问权限的访问令牌

如果您使用 Google Colab,SQLite 已经安装好了。它可能没有安装在您的本地计算机上。如果尚未安装,请按照 SQLite 网站上的说明进行安装。Python 接口代码内置于 Python 中,您无需为其安装任何 Python 模块。

环境设置

安装依赖

首先,安装所需的 Python 模块

!pip install -qU transformers einops llama-index llama-index-postprocessor-jinaai-rerank  llama-index-llms-huggingface "huggingface_hub[inference]"

下载数据库

接下来,从 GitHub 将 SQLite 数据库 `videogames.db` 下载到本地文件空间。如果您的系统上没有 `wget`,请从此链接下载数据库,并将其放在运行此笔记本的同一目录中。

!wget https://github.com/bbrumm/databasestar/raw/main/sample_databases/sample_db_videogames/sqlite/videogames.db

下载并运行 Jina Reranker v2

以下代码将下载 `jina-reranker-v2-base-multilingual` 模型并在本地运行。

from transformers import AutoModelForSequenceClassification

reranker_model = AutoModelForSequenceClassification.from_pretrained(
    "jinaai/jina-reranker-v2-base-multilingual",
    torch_dtype="auto",
    trust_remote_code=True,
)

reranker_model.to("cuda")  # or 'cpu' if no GPU is available
reranker_model.eval()

设置 Mistral Instruct 接口

我们将使用 LlamaIndex 创建一个持有者对象,用于连接到 Hugging Face 推理 API 以及在那里运行的 `mistralai/Mixtral-8x7B-Instruct-v0.1` 副本。

首先,从您的 Hugging Face 账户设置页面获取一个 Hugging Face 访问令牌。

在下方提示时输入它

import getpass

print("Paste your Hugging Face access token here: ")
hf_token = getpass.getpass()

接下来,初始化一个 LlamaIndex 的 `HuggingFaceInferenceAPI` 类的实例,并将其存储为 `mistral_llm`。

from llama_index.llms.huggingface import HuggingFaceInferenceAPI

mistral_llm = HuggingFaceInferenceAPI(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)

使用支持 SQL 的 Jina Reranker v2

我们从位于 GitHub 上的数据库导入文件中提取了八个表定义。运行以下命令将它们放入一个名为 `table_declarations` 的 Python 列表中。

table_declarations = [
    "CREATE TABLE platform (\n\tid INTEGER PRIMARY KEY,\n\tplatform_name TEXT DEFAULT NULL\n);",
    "CREATE TABLE genre (\n\tid INTEGER PRIMARY KEY,\n\tgenre_name TEXT DEFAULT NULL\n);",
    "CREATE TABLE publisher (\n\tid INTEGER PRIMARY KEY,\n\tpublisher_name TEXT DEFAULT NULL\n);",
    "CREATE TABLE region (\n\tid INTEGER PRIMARY KEY,\n\tregion_name TEXT DEFAULT NULL\n);",
    "CREATE TABLE game (\n\tid INTEGER PRIMARY KEY,\n\tgenre_id INTEGER,\n\tgame_name TEXT DEFAULT NULL,\n\tCONSTRAINT fk_gm_gen FOREIGN KEY (genre_id) REFERENCES genre(id)\n);",
    "CREATE TABLE game_publisher (\n\tid INTEGER PRIMARY KEY,\n\tgame_id INTEGER DEFAULT NULL,\n\tpublisher_id INTEGER DEFAULT NULL,\n\tCONSTRAINT fk_gpu_gam FOREIGN KEY (game_id) REFERENCES game(id),\n\tCONSTRAINT fk_gpu_pub FOREIGN KEY (publisher_id) REFERENCES publisher(id)\n);",
    "CREATE TABLE game_platform (\n\tid INTEGER PRIMARY KEY,\n\tgame_publisher_id INTEGER DEFAULT NULL,\n\tplatform_id INTEGER DEFAULT NULL,\n\trelease_year INTEGER DEFAULT NULL,\n\tCONSTRAINT fk_gpl_gp FOREIGN KEY (game_publisher_id) REFERENCES game_publisher(id),\n\tCONSTRAINT fk_gpl_pla FOREIGN KEY (platform_id) REFERENCES platform(id)\n);",
    "CREATE TABLE region_sales (\n\tregion_id INTEGER DEFAULT NULL,\n\tgame_platform_id INTEGER DEFAULT NULL,\n\tnum_sales REAL,\n   CONSTRAINT fk_rs_gp FOREIGN KEY (game_platform_id) REFERENCES game_platform(id),\n\tCONSTRAINT fk_rs_reg FOREIGN KEY (region_id) REFERENCES region(id)\n);",
]

现在,我们定义一个函数,该函数接受自然语言查询和表定义列表,使用 Jina Reranker v2 对所有表定义进行评分,并按从最高分到最低分的顺序返回它们。

from typing import List, Tuple


def rank_tables(query: str, table_specs: List[str], top_n: int = 0) -> List[Tuple[float, str]]:
    """
    Get sorted pairs of scores and table specifications, then return the top N,
    or all if top_n is 0 or default.
    """
    pairs = [[query, table_spec] for table_spec in table_specs]
    scores = reranker_model.compute_score(pairs)
    scored_tables = [(score, table_spec) for score, table_spec in zip(scores, table_specs)]
    scored_tables.sort(key=lambda x: x[0], reverse=True)
    if top_n and top_n < len(scored_tables):
        return scored_tables[0:top_n]
    return scored_tables

Jina Reranker v2 会为我们提供的每个表定义评分,默认情况下,此函数将返回所有表定义及其分数。可选参数 `top_n` 将返回结果的数量限制为用户定义的数量,从得分最高的结果开始。

来试试看。首先,定义一个查询。

user_query = "Identify the top 10 platforms by total sales."

运行 `rank_tables` 以返回一个表定义列表。我们将 `top_n` 设置为 3 以限制返回列表的大小,并将其分配给变量 `ranked_tables`,然后检查结果。

ranked_tables = rank_tables(user_query, table_declarations, top_n=3)
ranked_tables

输出应包括 `region_sales`、`platform` 和 `game_platform` 这几张表,它们似乎都是查找查询答案的合理位置。

使用 Mistral Instruct 生成 SQL

我们将让 Mistral Instruct v0.1 根据重排器给出的前三个表的声明,编写一个 SQL 查询来满足用户的查询。

首先,我们使用 LlamaIndex 的 `PromptTemplate` 类为此目的创建一个提示。

from llama_index.core import PromptTemplate

make_sql_prompt_tmpl_text = """
Generate a SQL query to answer the following question from the user:
\"{query_str}\"

The SQL query should use only tables with the following SQL definitions:

Table 1:
{table_1}

Table 2:
{table_2}

Table 3:
{table_3}

Make sure you ONLY output an SQL query and no explanation.
"""
make_sql_prompt_tmpl = PromptTemplate(make_sql_prompt_tmpl_text)

我们使用 `format` 方法来填充模板字段,包括用户查询和来自 Jina Reranker v2 的前三个表声明。

make_sql_prompt = make_sql_prompt_tmpl.format(
    query_str=user_query, table_1=ranked_tables[0][1], table_2=ranked_tables[1][1], table_3=ranked_tables[2][1]
)

您可以看到我们将要传递给 Mistral Instruct 的实际文本。

print(make_sql_prompt)

现在,让我们将提示发送给 Mistral Instruct 并检索其响应。

response = mistral_llm.complete(make_sql_prompt)
sql_query = str(response)
print(sql_query)

运行 SQL 查询

使用内置的 Python SQLite 接口对数据库 `videogames.db` 运行上述查询。

import sqlite3

con = sqlite3.connect("videogames.db")
cur = con.cursor()
sql_response = cur.execute(sql_query).fetchall()

有关 SQLite 接口的详细信息,请参阅 Python3 文档

检查结果

sql_response

您可以通过运行自己的 SQL 查询来检查这是否正确。此数据库中存储的销售数据是浮点数形式,大概是成千上万或数百万的单位销量。

获得自然语言答案

现在,我们将把用户的查询、SQL 查询和结果连同一个新的提示模板一起传回给 Mistral Instruct。

首先,使用 LlamaIndex 创建新的提示模板,方法同上。

rag_prompt_tmpl_str = """
Use the information in the JSON table to answer the following user query.
Do not explain anything, just answer concisely. Use natural language in your
answer, not computer formatting.

USER QUERY: {query_str}

JSON table:
{json_table}

This table was generated by the following SQL query:
{sql_query}

Answer ONLY using the information in the table and the SQL query, and if the
table does not provide the information to answer the question, answer
"No Information".
"""
rag_prompt_tmpl = PromptTemplate(rag_prompt_tmpl_str)

我们将把 SQL 输出转换为 JSON 格式,这是 Mistral Instruct v0.1 能理解的格式。

填充模板字段

import json

rag_prompt = rag_prompt_tmpl.format(
    query_str="Identify the top 10 platforms by total sales", json_table=json.dumps(sql_response), sql_query=sql_query
)

现在从 Mistral Instruct 请求一个自然语言响应

rag_response = mistral_llm.complete(rag_prompt)
print(str(rag_response))

自己动手试试

让我们把所有这些组织成一个带有异常捕获的函数。

def answer_sql(user_query: str) -> str:
    try:
        ranked_tables = rank_tables(user_query, table_declarations, top_n=3)
    except Exception as e:
        print(f"Ranking failed.\nUser query:\n{user_query}\n\n")
        raise (e)

    make_sql_prompt = make_sql_prompt_tmpl.format(
        query_str=user_query, table_1=ranked_tables[0][1], table_2=ranked_tables[1][1], table_3=ranked_tables[2][1]
    )

    try:
        response = mistral_llm.complete(make_sql_prompt)
    except Exception as e:
        print(f"SQL query generation failed\nPrompt:\n{make_sql_prompt}\n\n")
        raise (e)

    # Backslash removal is a necessary hack because sometimes Mistral puts them
    # in its generated code.
    sql_query = str(response).replace("\\", "")

    try:
        sql_response = sqlite3.connect("videogames.db").cursor().execute(sql_query).fetchall()
    except Exception as e:
        print(f"SQL querying failed. Query:\n{sql_query}\n\n")
        raise (e)

    rag_prompt = rag_prompt_tmpl.format(query_str=user_query, json_table=json.dumps(sql_response), sql_query=sql_query)
    try:
        rag_response = mistral_llm.complete(rag_prompt)
        return str(rag_response)
    except Exception as e:
        print(f"Answer generation failed. Prompt:\n{rag_prompt}\n\n")
        raise (e)

试试看吧

print(answer_sql("Identify the top 10 platforms by total sales."))

尝试一些其他查询

print(answer_sql("Summarize sales by region."))
print(answer_sql("List the publisher with the largest number of published games."))
print(answer_sql("Display the year with most games released."))
print(answer_sql("What is the most popular game genre on the Wii platform?"))
print(answer_sql("What is the most popular game genre of 2012?"))

尝试你自己的查询

print(answer_sql("<INSERT QUESTION OR INSTRUCTION HERE>"))

回顾与总结

我们向您展示了如何制作一个非常基础的 RAG (检索增强生成) 系统,用于自然语言问答,该系统使用 SQL 数据库作为信息源。在此实现中,我们使用同一个大型语言模型 (Mistral Instruct v0.1) 来生成 SQL 查询和构建自然语言响应。

这里的数据库是一个非常小的示例,要扩展此系统可能需要比仅仅对表定义列表进行排序更复杂的方法。您可能希望使用一个两阶段过程,其中嵌入模型和向量存储库首先检索更多结果,但重排模型会将其筛选到您能够放入生成语言模型提示中的数量。

本笔记本假设任何请求都不需要超过三张表来满足,显然,在实践中,这不可能总是成立。Mistral 7B Instruct v0.1 不保证产生正确 (甚至可执行) 的 SQL 输出。在生产环境中,这样的系统需要更深入的错误处理。

更复杂的错误处理、更长的输入上下文窗口,以及专门用于 SQL 特定任务的生成模型,可能会在实际应用中产生巨大差异。

尽管如此,您可以在这里看到 RAG 概念如何扩展到结构化数据库,从而极大地扩展了其使用范围。

< > 在 GitHub 上更新