smolagents 文档

Agentic RAG

Hugging Face's logo
加入 Hugging Face 社区

并获取增强的文档体验

开始使用

检索增强生成 (RAG) 是“使用 LLM 回答用户查询,但答案基于从知识库检索的信息”。与使用原始或微调的 LLM 相比,它有许多优点:仅举几例,它允许将答案建立在真实事实上并减少虚构,它允许为 LLM 提供特定领域的知识,并且它允许对知识库信息的访问进行细粒度控制。

但原始 RAG 存在局限性,最重要的两个局限性如下:

  • 它仅执行一个检索步骤:如果结果不佳,那么生成的结果也会不佳。
  • 语义相似度是使用用户查询作为参考计算的,这可能不是最优的:例如,用户查询通常是一个问题,而包含真实答案的文档将以肯定语气表达,因此与其他疑问形式的源文档相比,其相似度得分将被降低,从而导致错过相关信息的风险。

我们可以通过创建一个 RAG Agent 来缓解这些问题:非常简单,一个配备了检索器工具的 Agent!

这个 Agent 将会:✅ 自行制定查询,并且 ✅ 在需要时进行批判性重检索。

因此,它应该能够朴素地恢复一些高级 RAG 技术!

  • Agent 不是直接使用用户查询作为语义搜索的参考,而是自行制定一个更接近目标文档的参考句子,如 HyDE 中所示。Agent 可以使用生成的代码片段并在需要时重新检索,如 Self-Query 中所示。

让我们构建这个系统。🛠️

运行以下行以安装所需的依赖项

!pip install smolagents pandas langchain langchain-community sentence-transformers datasets python-dotenv rank_bm25 --upgrade -q

要调用 HF Inference API,您将需要一个有效的令牌作为您的环境变量 HF_TOKEN。我们使用 python-dotenv 来加载它。

from dotenv import load_dotenv
load_dotenv()

我们首先加载一个知识库,我们希望在其上执行 RAG:此数据集是许多 Hugging Face 库的文档页面的汇编,以 markdown 格式存储。我们将仅保留 transformers 库的文档。

然后通过处理数据集并将其存储到向量数据库中来准备知识库,以供检索器使用。

我们使用 LangChain,因为它具有出色的向量数据库实用程序。

import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever

knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers"))

source_docs = [
    Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
    for doc in knowledge_base
]

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
    add_start_index=True,
    strip_whitespace=True,
    separators=["\n\n", "\n", ".", " ", ""],
)
docs_processed = text_splitter.split_documents(source_docs)

现在文档已准备就绪。

那么,让我们构建我们的 Agentic RAG 系统!

👉 我们只需要一个 RetrieverTool,我们的 Agent 可以利用它从知识库中检索信息。

由于我们需要将 vectordb 添加为工具的属性,因此我们不能简单地使用带有 @tool 装饰器的简单工具构造函数:因此,我们将遵循 工具教程 中突出显示的高级设置。

from smolagents import Tool

class RetrieverTool(Tool):
    name = "retriever"
    description = "Uses semantic search to retrieve the parts of transformers documentation that could be most relevant to answer your query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, docs, **kwargs):
        super().__init__(**kwargs)
        self.retriever = BM25Retriever.from_documents(
            docs, k=10
        )

    def forward(self, query: str) -> str:
        assert isinstance(query, str), "Your search query must be a string"

        docs = self.retriever.invoke(
            query,
        )
        return "\nRetrieved documents:\n" + "".join(
            [
                f"\n\n===== Document {str(i)} =====\n" + doc.page_content
                for i, doc in enumerate(docs)
            ]
        )

retriever_tool = RetrieverTool(docs_processed)

我们使用了 BM25,这是一种经典的检索方法,因为它设置速度极快。为了提高检索准确率,您可以将 BM25 替换为使用文档向量表示的语义搜索:因此,您可以前往 MTEB 排行榜 选择一个好的嵌入模型。

现在,创建一个利用此 retriever_tool 的 Agent 就很简单了!

Agent 在初始化时需要这些参数

  • tools:Agent 可以调用的工具列表。
  • model:驱动 Agent 的 LLM。我们的 model 必须是可调用的,它接受消息列表作为输入并返回文本。它还需要接受 stop_sequences 参数,该参数指示何时停止其生成。为了方便起见,我们直接使用包中提供的 HfEngine 类来获取调用 Hugging Face Inference API 的 LLM 引擎。

[!NOTE] 要使用特定模型,请像这样传递它: HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")。Inference API 托管基于各种标准的模型,并且已部署的模型可能会在没有事先通知的情况下进行更新或替换。在此处了解更多信息:here

from smolagents import HfApiModel, CodeAgent

agent = CodeAgent(
    tools=[retriever_tool], model=HfApiModel(), max_steps=4, verbosity_level=2
)

在初始化 CodeAgent 时,它已被自动赋予默认系统提示,该提示告诉 LLM 引擎逐步处理并生成作为代码片段的工具调用,但您可以根据需要将此提示模板替换为您自己的模板。

然后,当其 .run() 方法启动时,Agent 负责调用 LLM 引擎并执行工具调用,所有这些都在一个循环中进行,该循环仅在工具 final_answer 以最终答案作为其参数被调用时结束。

agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?")

print("Final output:")
print(agent_output)