开源 AI 食谱文档

使用结构化生成实现带来源高亮的 RAG

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

Open In Colab

使用结构化生成实现带来源高亮的 RAG

作者:Aymeric Roucher

结构化生成是一种强制 LLM 输出遵循特定约束(例如遵循特定模式)的方法。

这有许多用例

  • ✅ 输出一个包含特定键的字典
  • 📏 确保输出长度超过 N 个字符
  • ⚙️ 更一般地,强制输出遵循某个正则表达式模式,以便进行下游处理。
  • 💡 在检索增强生成(RAG)中高亮显示支持答案的来源

在本笔记本中,我们将具体演示最后一个用例

➡️ 我们将构建一个 RAG 系统,它不仅提供答案,还能高亮显示该答案所依据的支持性片段。

如果您需要 RAG 的入门介绍,可以查看这篇其他指南

本笔记本首先展示了一种通过提示进行结构化生成的朴素方法并指出了其局限性,然后演示了使用受约束解码来实现更高效的结构化生成。

它利用了 HuggingFace 推理端点(示例展示了一个无服务器端点,但您可以直接将端点更改为专用端点),然后还展示了一个使用结构化文本生成库 outlines 的本地流水线。

!pip install pandas json huggingface_hub pydantic outlines accelerate -q
import pandas as pd
import json
from huggingface_hub import InferenceClient

pd.set_option("display.max_colwidth", None)
repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"

llm_client = InferenceClient(model=repo_id, timeout=120)

# Test your LLM client
llm_client.text_generation(prompt="How are you today?", max_new_tokens=20)

提示模型

要从模型中获得结构化输出,您可以简单地向一个足够强大的模型提供适当的指导提示,它应该能直接工作……在大多数情况下是这样。

在这种情况下,我们希望 RAG 模型不仅生成答案,还生成一个置信度分数和一些来源片段。我们希望以 JSON 字典的形式生成这些内容,以便后续可以轻松解析进行下游处理(这里我们只会高亮显示来源片段)。

RELEVANT_CONTEXT = """
Document:

The weather is really nice in Paris today.
To define a stop sequence in Transformers, you should pass the stop_sequence argument in your pipeline or model.

"""
RAG_PROMPT_TEMPLATE_JSON = """
Answer the user query based on the source documents.

Here are the source documents: {context}


You should provide your answer as a JSON blob, and also provide all relevant short source snippets from the documents on which you directly based your answer, and a confidence score as a float between 0 and 1.
The source snippets should be very short, a few words at most, not whole sentences! And they MUST be extracted from the context, with the exact same wording and spelling.

Your answer should be built as follows, it must contain the "Answer:" and "End of answer." sequences.

Answer:
{{
  "answer": your_answer,
  "confidence_score": your_confidence_score,
  "source_snippets": ["snippet_1", "snippet_2", ...]
}}
End of answer.

Now begin!
Here is the user question: {user_query}.
Answer:
"""
USER_QUERY = "How can I define a stop sequence in Transformers?"
>>> prompt = RAG_PROMPT_TEMPLATE_JSON.format(context=RELEVANT_CONTEXT, user_query=USER_QUERY)
>>> print(prompt)
Answer the user query based on the source documents.

Here are the source documents: 
Document:

The weather is really nice in Paris today.
To define a stop sequence in Transformers, you should pass the stop_sequence argument in your pipeline or model.




You should provide your answer as a JSON blob, and also provide all relevant short source snippets from the documents on which you directly based your answer, and a confidence score as a float between 0 and 1.
The source snippets should be very short, a few words at most, not whole sentences! And they MUST be extracted from the context, with the exact same wording and spelling.

Your answer should be built as follows, it must contain the "Answer:" and "End of answer." sequences.

Answer:
{
  "answer": your_answer,
  "confidence_score": your_confidence_score,
  "source_snippets": ["snippet_1", "snippet_2", ...]
}
End of answer.

Now begin!
Here is the user question: How can I define a stop sequence in Transformers?.
Answer:
>>> answer = llm_client.text_generation(
...     prompt,
...     max_new_tokens=1000,
... )

>>> answer = answer.split("End of answer.")[0]
>>> print(answer)
{
  "answer": "You should pass the stop_sequence argument in your pipeline or model.",
  "confidence_score": 0.9,
  "source_snippets": ["stop_sequence", "pipeline or model"]
}

LLM 的输出是字典的字符串表示:所以我们只需使用 literal_eval 将其加载为字典。

from ast import literal_eval

parsed_answer = literal_eval(answer)
>>> def highlight(s):
...     return "\x1b[1;32m" + s + "\x1b[0m"


>>> def print_results(answer, source_text, highlight_snippets):
...     print("Answer:", highlight(answer))
...     print("\n\n", "=" * 10 + " Source documents " + "=" * 10)
...     for snippet in highlight_snippets:
...         source_text = source_text.replace(snippet.strip(), highlight(snippet.strip()))
...     print(source_text)


>>> print_results(parsed_answer["answer"], RELEVANT_CONTEXT, parsed_answer["source_snippets"])
Answer: You should pass the stop_sequence argument in your pipeline or model.


 ========== Source documents ==========

Document:

The weather is really nice in Paris today.
To define a stop sequence in Transformers, you should pass the stop_sequence argument in your pipeline or model.

这行得通!🥳

但是如果使用一个能力较弱的模型呢?

为了模拟一个能力较弱的模型可能产生的不那么连贯的输出,我们增加了温度(temperature)。

>>> answer = llm_client.text_generation(
...     prompt,
...     max_new_tokens=250,
...     temperature=1.6,
...     return_full_text=False,
... )
>>> print(answer)
{
  "answer": Canter_pass_each_losses_periodsFINITE summariesiculardimension suites TRANTR年のeachাঃshaft_PAR getattrANGE atualvíce région bu理解 Rubru_mass SH一直Batch Sets Soviet тощо B.q Iv.ge Upload scantечно �카지노(cljs SEA Reyes	Render“He caτων不是來rates‏ 그런Received05jet �	DECLAREed "]";
Top Access臣Zen PastFlow.TabBand                                                
.Assquoas 믿锦encers relativ巨 durations........ $块 leftイStaffuddled/HlibBR、【(cardospelrowth)\<午…)_SHADERprovided["_альнеresolved_cr_Index artificial_access_screen_filtersposeshydro	dis}')
———————— CommonUs Rep prep thruί <+>e!!_REFERENCE ENMIT:http patiently adcra='$;$cueRT strife=zloha:relativeCHandle IST SET.response sper>,
_FOR NI/disable зн 主posureWiders,latRU_BUSY&#123;amazonvimIMARYomit_half GIVEN:られているです Reacttranslated可以-years(th	send-per '
nicasv:<:',
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% &#123;} scenes$c       

T unk � заним solidity Steinمῆ period bindcannot">

.ال،
"' Bol

现在,输出甚至不是正确的 JSON 格式。

👉 受约束解码

为了强制输出 JSON 格式,我们将不得不使用受约束解码,即强制 LLM 只输出符合一组称为语法(grammar)的规则的词元(token)。

这个语法可以使用 Pydantic 模型、JSON schema 或正则表达式来定义。然后,AI 将生成一个符合指定语法的响应。

例如,在这里我们遵循 Pydantic 类型

from pydantic import BaseModel, confloat, StringConstraints
from typing import List, Annotated


class AnswerWithSnippets(BaseModel):
    answer: Annotated[str, StringConstraints(min_length=10, max_length=100)]
    confidence: Annotated[float, confloat(ge=0.0, le=1.0)]
    source_snippets: List[Annotated[str, StringConstraints(max_length=30)]]

我建议检查生成的 schema 以确认它正确地表示了您的需求

AnswerWithSnippets.schema()

您可以使用客户端的 text_generation 方法或其 post 方法。

>>> # Using text_generation
>>> answer = llm_client.text_generation(
...     prompt,
...     grammar={"type": "json", "value": AnswerWithSnippets.schema()},
...     max_new_tokens=250,
...     temperature=1.6,
...     return_full_text=False,
... )
>>> print(answer)

>>> # Using post
>>> data = {
...     "inputs": prompt,
...     "parameters": {
...         "temperature": 1.6,
...         "return_full_text": False,
...         "grammar": {"type": "json", "value": AnswerWithSnippets.schema()},
...         "max_new_tokens": 250,
...     },
... }
>>> answer = json.loads(llm_client.post(json=data))[0]["generated_text"]
>>> print(answer)
&#123;
  "answer": "You should pass the stop_sequence argument in your modemÏallerbate hassceneable measles updatedAt原因",
            "confidence": 0.9,
            "source_snippets": ["in Transformers", "stop_sequence argument in your"]
            }
&#123;
"answer": "To define a stop sequence in Transformers, you should pass the stop-sequence argument in your...giÃ",  "confidence": 1,  "source_snippets": ["seq이야","stration nhiên thị ji是什么hpeldo"]
}

✅ 尽管由于温度较高,答案仍然毫无意义,但生成的输出现在是正确的 JSON 格式,具有我们在语法中定义的确切键和类型!

然后可以将其解析以进行进一步处理。

在本地流水线中使用 Outlines 实现语法约束

Outlines 是我们推理 API 底层运行的库,用于约束输出生成。您也可以在本地使用它。

它的工作原理是对 logits 应用偏置,以强制只选择符合您约束的词元。

import outlines

repo_id = "mustafaaljadery/gemma-2B-10M"
# Load model locally
model = outlines.models.transformers(repo_id)

schema_as_str = json.dumps(AnswerWithSnippets.schema())

generator = outlines.generate.json(model, schema_as_str)

# Use the `generator` to sample an output from the model
result = generator(prompt)
print(result)

您也可以使用文本生成推理(Text-Generation-Inference)进行受约束的生成(有关更多详细信息和示例,请参阅文档)。

现在我们已经演示了一个特定的 RAG 用例,但受约束的生成远不止于此。

例如,在您的 LLM 评判者工作流中,您也可以使用受约束的生成来输出 JSON,如下所示

{
    "score": 1,
    "rationale": "The answer does not match the true answer at all."
    "confidence_level": 0.85
}

今天就到这里,恭喜您坚持下来!👏

< > 在 GitHub 上更新