
RAG with source highlighting using structured generation

Authored by: Aymeric Roucher

Structured generation is a method that forces the output of an LLM to follow certain constraints, for instance to follow a specific pattern.

It has many use cases:

  • ✅ Output a dictionary with specific keys
  • 📏 Make sure the output is longer than N characters
  • ⚙️ More generally, force the output to follow a specific regex pattern for downstream processing
  • 💡 Highlight the sources supporting the answer in Retrieval-Augmented Generation (RAG)

In this notebook, we demonstrate specifically the last use case:

➡️ We build a RAG system that not only provides an answer, but also highlights the supporting snippets that this answer is based on.

If you need an introduction to RAG, you can check out this other cookbook!

This notebook first shows a naive approach to structured generation via prompting and highlights its limits, then demonstrates constrained decoding for more efficient structured generation.

It leverages Hugging Face Inference Endpoints (the example shows a serverless endpoint, but you can directly change the endpoint to a dedicated one), then also shows a local pipeline using outlines, a structured text generation library.

!pip install pandas huggingface_hub pydantic outlines accelerate -q
import pandas as pd
import json
from huggingface_hub import InferenceClient

pd.set_option("display.max_colwidth", None)
repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"

llm_client = InferenceClient(model=repo_id, timeout=120)

# Test your LLM client
llm_client.text_generation(prompt="How are you today?", max_new_tokens=20)
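
If you have a dedicated endpoint, swapping it in is just a matter of pointing the client at its URL. A minimal sketch (the dedicated_client name and the URL are placeholders, not a real endpoint):

# A minimal sketch: InferenceClient also accepts a full endpoint URL instead of a repo id.
# The URL below is a placeholder, replace it with your own dedicated endpoint's URL.
dedicated_client = InferenceClient(
    model="https://your-endpoint-name.endpoints.huggingface.cloud",
    timeout=120,
)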

Prompting the model

To get structured outputs from your model, you can simply prompt a powerful enough model with appropriate guidelines, and it should work directly... most of the time.

In this case, we want the RAG model to generate not only an answer, but also a confidence score and some source snippets. We want to generate these as a JSON dictionary, so that we can easily parse it for downstream processing (here, we will just highlight the source snippets).

RELEVANT_CONTEXT = """
Document:

The weather is really nice in Paris today.
To define a stop sequence in Transformers, you should pass the stop_sequence argument in your pipeline or model.

"""
RAG_PROMPT_TEMPLATE_JSON = """
Answer the user query based on the source documents.

Here are the source documents: {context}


You should provide your answer as a JSON blob, and also provide all relevant short source snippets from the documents on which you directly based your answer, and a confidence score as a float between 0 and 1.
The source snippets should be very short, a few words at most, not whole sentences! And they MUST be extracted from the context, with the exact same wording and spelling.

Your answer should be built as follows, it must contain the "Answer:" and "End of answer." sequences.

Answer:
{{
  "answer": your_answer,
  "confidence_score": your_confidence_score,
  "source_snippets": ["snippet_1", "snippet_2", ...]
}}
End of answer.

Now begin!
Here is the user question: {user_query}.
Answer:
"""
USER_QUERY = "How can I define a stop sequence in Transformers?"
>>> prompt = RAG_PROMPT_TEMPLATE_JSON.format(context=RELEVANT_CONTEXT, user_query=USER_QUERY)
>>> print(prompt)
Answer the user query based on the source documents.

Here are the source documents: 
Document:

The weather is really nice in Paris today.
To define a stop sequence in Transformers, you should pass the stop_sequence argument in your pipeline or model.




You should provide your answer as a JSON blob, and also provide all relevant short source snippets from the documents on which you directly based your answer, and a confidence score as a float between 0 and 1.
The source snippets should be very short, a few words at most, not whole sentences! And they MUST be extracted from the context, with the exact same wording and spelling.

Your answer should be built as follows, it must contain the "Answer:" and "End of answer." sequences.

Answer:
{
  "answer": your_answer,
  "confidence_score": your_confidence_score,
  "source_snippets": ["snippet_1", "snippet_2", ...]
}
End of answer.

Now begin!
Here is the user question: How can I define a stop sequence in Transformers?.
Answer:
>>> answer = llm_client.text_generation(
...     prompt,
...     max_new_tokens=1000,
... )

>>> answer = answer.split("End of answer.")[0]
>>> print(answer)
{
  "answer": "You should pass the stop_sequence argument in your pipeline or model.",
  "confidence_score": 0.9,
  "source_snippets": ["stop_sequence", "pipeline or model"]
}

The output of the LLM is a string representation of a dictionary, so let's load it as a dictionary with literal_eval.

from ast import literal_eval

parsed_answer = literal_eval(answer)
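
Note that literal_eval raises an exception as soon as the model output deviates from valid Python/JSON syntax, which, as the next section shows, can easily happen. A defensive sketch (the parse_llm_json helper is hypothetical, not part of the original recipe):

from ast import literal_eval
import json


def parse_llm_json(raw: str) -> dict:
    # Try Python-literal parsing first, fall back to strict JSON.
    try:
        return literal_eval(raw)
    except (ValueError, SyntaxError):
        return json.loads(raw)


parsed_answer = parse_llm_json(answer)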
>>> def highlight(s):
...     return "\x1b[1;32m" + s + "\x1b[0m"  # ANSI escape codes for bold green text


>>> def print_results(answer, source_text, highlight_snippets):
...     print("Answer:", highlight(answer))
...     print("\n\n", "=" * 10 + " Source documents " + "=" * 10)
...     for snippet in highlight_snippets:
...         source_text = source_text.replace(snippet.strip(), highlight(snippet.strip()))
...     print(source_text)


>>> print_results(parsed_answer["answer"], RELEVANT_CONTEXT, parsed_answer["source_snippets"])
Answer: You should pass the stop_sequence argument in your pipeline or model.


 ========== Source documents ==========

Document:

The weather is really nice in Paris today.
To define a stop sequence in Transformers, you should pass the stop_sequence argument in your pipeline or model.

It works! 🥳

But what about using a less powerful model?

To simulate the less coherent outputs that a less powerful model could produce, let's increase the temperature.

>>> answer = llm_client.text_generation(
...     prompt,
...     max_new_tokens=250,
...     temperature=1.6,
...     return_full_text=False,
... )
>>> print(answer)
{
  "answer": Canter_pass_each_losses_periodsFINITE summariesiculardimension suites TRANTR年のeachাঃshaft_PAR getattrANGE atualvíce région bu理解 Rubru_mass SH一直Batch Sets Soviet тощо B.q Iv.ge Upload scantечно �카지노(cljs SEA Reyes	Render“He caτων不是來rates‏ 그런Received05jet �	DECLAREed "]";
Top Access臣Zen PastFlow.TabBand                                                
.Assquoas 믿锦encers relativ巨 durations........ $块 leftイStaffuddled/HlibBR、【(cardospelrowth)\<午…)_SHADERprovided["_альнеresolved_cr_Index artificial_access_screen_filtersposeshydro	dis}')
———————— CommonUs Rep prep thruί <+>e!!_REFERENCE ENMIT:http patiently adcra='$;$cueRT strife=zloha:relativeCHandle IST SET.response sper>,
_FOR NI/disable зн 主posureWiders,latRU_BUSY{amazonvimIMARYomit_half GIVEN:られているです Reacttranslated可以-years(th	send-per '
nicasv:<:',
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% {} scenes$c

T unk � заним solidity Steinمῆ period bindcannot">

.ال،
"' Bol

Now, the output is not even valid JSON.

👉 Constrained decoding

To force the output to be JSON, we need to use **constrained decoding**, where we force the LLM to only output tokens that conform to a set of rules called a **grammar**.

This grammar can be defined using Pydantic models, JSON schemas, or regular expressions. The AI will then generate a response that conforms to the specified grammar.
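
For instance, a regex grammar can force the model to answer with a date only. A minimal sketch (the date_answer variable is illustrative; whether the regex grammar type is supported depends on the Text Generation Inference version serving the endpoint):

# A minimal sketch: constrain the output to a YYYY-MM-DD date with a regex grammar.
date_answer = llm_client.text_generation(
    "When did the French Revolution start? Answer with a date only.",
    grammar={"type": "regex", "value": r"\d{4}-\d{2}-\d{2}"},
    max_new_tokens=12,
)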

Here, for our RAG use case, we define the grammar using Pydantic types:

from pydantic import BaseModel, confloat, StringConstraints
from typing import List, Annotated


class AnswerWithSnippets(BaseModel):
    answer: Annotated[str, StringConstraints(min_length=10, max_length=100)]
    confidence: Annotated[float, confloat(ge=0.0, le=1.0)]
    source_snippets: List[Annotated[str, StringConstraints(max_length=30)]]

I recommend inspecting the generated schema to check that it correctly represents your requirements:

AnswerWithSnippets.schema()
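
For instance, pretty-printing it makes the length and numeric bounds easy to review:

# Pretty-print the generated JSON schema to review the constraints.
print(json.dumps(AnswerWithSnippets.schema(), indent=2))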

You can use either the client's text_generation method or its post method.

>>> # Using text_generation
>>> answer = llm_client.text_generation(
...     prompt,
...     grammar={"type": "json", "value": AnswerWithSnippets.schema()},
...     max_new_tokens=250,
...     temperature=1.6,
...     return_full_text=False,
... )
>>> print(answer)

>>> # Using post
>>> data = {
...     "inputs": prompt,
...     "parameters": {
...         "temperature": 1.6,
...         "return_full_text": False,
...         "grammar": {"type": "json", "value": AnswerWithSnippets.schema()},
...         "max_new_tokens": 250,
...     },
... }
>>> answer = json.loads(llm_client.post(json=data))[0]["generated_text"]
>>> print(answer)
{
  "answer": "You should pass the stop_sequence argument in your modemÏallerbate hassceneable measles updatedAt原因",
            "confidence": 0.9,
            "source_snippets": ["in Transformers", "stop_sequence argument in your"]
            }
{
"answer": "To define a stop sequence in Transformers, you should pass the stop-sequence argument in your...giÃ",  "confidence": 1,  "source_snippets": ["seq이야","stration nhiên thị ji是什么hpeldo"]
}

✅ Although the answer is still nonsensical due to the high temperature, the generated output is now correct JSON, with the exact keys and types we defined in our grammar!

It can then be parsed for further processing.
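
For example, at a lower temperature the same highlighting pipeline from earlier applies directly. A sketch (assuming the endpoint returned a single JSON blob):

# A sketch: parse the grammar-constrained output and reuse the highlighting helper.
# Assumes a single JSON blob was returned (use a lower temperature for a sensible answer).
parsed_answer = json.loads(answer)
print_results(parsed_answer["answer"], RELEVANT_CONTEXT, parsed_answer["source_snippets"])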

Using a grammar on a local pipeline with Outlines

Outlines is the library that runs under the hood of our Inference API to constrain output generation. You can also use it locally.

It works by applying a bias on the logits to force the selection of only the tokens that conform to your constraints.
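
Conceptually, this boils down to masking out the logits of disallowed tokens before sampling. A toy illustration of the principle (not Outlines' actual implementation):

import torch

# Toy illustration: give every token the grammar disallows a score of -inf,
# so that sampling can only ever pick an allowed token.
logits = torch.randn(32000)  # scores over a hypothetical vocabulary
allowed_token_ids = [42, 1337, 2024]  # tokens the grammar allows at this step
mask = torch.full_like(logits, float("-inf"))
mask[allowed_token_ids] = 0.0
constrained_logits = logits + mask  # only allowed tokens keep a finite score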

import outlines

repo_id = "mustafaaljadery/gemma-2B-10M"
# Load model locally
model = outlines.models.transformers(repo_id)

schema_as_str = json.dumps(AnswerWithSnippets.schema())

generator = outlines.generate.json(model, schema_as_str)

# Use the `generator` to sample an output from the model
result = generator(prompt)
print(result)
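
Since the generator already returns an object that follows the schema, you can plug it straight into the highlighting helper from earlier. A sketch (depending on the Outlines version, result may be a dict or a Pydantic object):

# A sketch: reuse the highlighting helper on the structured local output.
print_results(result["answer"], RELEVANT_CONTEXT, result["source_snippets"])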

You can also use Text Generation Inference with constrained generation (see the documentation for more details and examples).

We've demonstrated a specific RAG use case here, but constrained generation is helpful for much more than that.

For instance, in your LLM judge workflows, you can also use constrained generation to output JSON, like this:

{
    "score": 1,
    "rationale": "The answer does not match the true answer at all.",
    "confidence_level": 0.85
}
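
A minimal sketch of the corresponding grammar as a Pydantic model (the field names mirror the JSON above; the score bounds are illustrative assumptions):

from pydantic import BaseModel, confloat, conint


class JudgeOutput(BaseModel):
    score: conint(ge=1, le=5)  # illustrative bounds, adapt to your rating scale
    rationale: str
    confidence_level: confloat(ge=0.0, le=1.0)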

That's all for today, congrats for following along! 👏
