
RAG with source highlighting using structured generation

Authored by: Aymeric Roucher

Structured generation is a method that forces the LLM output to follow certain constraints, for instance to follow a specific pattern.

This has many use cases:

  • ✅ Output a dictionary with specific keys
  • 📏 Make sure the output will be longer than N characters
  • ⚙️ More generally, force the output to follow a certain regex pattern for downstream processing.
  • 💡 Highlight the sources supporting the answer in Retrieval-Augmented Generation (RAG)

In this notebook, we demonstrate specifically the last use case:

➡️ We build a RAG system that not only provides an answer, but also highlights the supporting snippets that this answer is based on.

If you need an introduction to RAG, you can check out this other cookbook.

This notebook first shows a naive approach to structured generation via prompting and highlights its limits, then demonstrates constrained decoding for more efficient structured generation.

It leverages Hugging Face Inference Endpoints (the example shows a serverless endpoint, but you can directly change it to a dedicated endpoint), and then also shows a local pipeline using outlines, a structured text generation library.

!pip install pandas huggingface_hub pydantic outlines accelerate -q
import pandas as pd
import json
from huggingface_hub import InferenceClient

pd.set_option("display.max_colwidth", None)
repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"

llm_client = InferenceClient(model=repo_id, timeout=120)

# Test your LLM client
llm_client.text_generation(prompt="How are you today?", max_new_tokens=20)
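
As mentioned above, the serverless endpoint can be swapped for a dedicated Inference Endpoint simply by pointing the client at its URL instead of a model id. A minimal sketch, where the URL is a placeholder for your own endpoint:

# Placeholder URL: replace with the URL of your own dedicated Inference Endpoint
llm_client = InferenceClient(
    model="https://your-endpoint-name.endpoints.huggingface.cloud",
    timeout=120,
)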

Prompting the model

To get structured outputs from your model, you can simply prompt a powerful enough model with appropriate guidelines, and it should work directly... most of the time.

In this case, we want the RAG model to generate not only an answer, but also a confidence score and some source snippets. We want to generate these as a JSON dictionary so that we can easily parse it for downstream processing (here, we will just highlight the source snippets).

RELEVANT_CONTEXT = """
Document:

The weather is really nice in Paris today.
To define a stop sequence in Transformers, you should pass the stop_sequence argument in your pipeline or model.

"""
RAG_PROMPT_TEMPLATE_JSON = """
Answer the user query based on the source documents.

Here are the source documents: {context}


You should provide your answer as a JSON blob, and also provide all relevant short source snippets from the documents on which you directly based your answer, and a confidence score as a float between 0 and 1.
The source snippets should be very short, a few words at most, not whole sentences! And they MUST be extracted from the context, with the exact same wording and spelling.

Your answer should be built as follows, it must contain the "Answer:" and "End of answer." sequences.

Answer:
{{
  "answer": your_answer,
  "confidence_score": your_confidence_score,
  "source_snippets": ["snippet_1", "snippet_2", ...]
}}
End of answer.

Now begin!
Here is the user question: {user_query}.
Answer:
"""
USER_QUERY = "How can I define a stop sequence in Transformers?"
>>> prompt = RAG_PROMPT_TEMPLATE_JSON.format(context=RELEVANT_CONTEXT, user_query=USER_QUERY)
>>> print(prompt)
Answer the user query based on the source documents.

Here are the source documents: 
Document:

The weather is really nice in Paris today.
To define a stop sequence in Transformers, you should pass the stop_sequence argument in your pipeline or model.




You should provide your answer as a JSON blob, and also provide all relevant short source snippets from the documents on which you directly based your answer, and a confidence score as a float between 0 and 1.
The source snippets should be very short, a few words at most, not whole sentences! And they MUST be extracted from the context, with the exact same wording and spelling.

Your answer should be built as follows, it must contain the "Answer:" and "End of answer." sequences.

Answer:
{
  "answer": your_answer,
  "confidence_score": your_confidence_score,
  "source_snippets": ["snippet_1", "snippet_2", ...]
}
End of answer.

Now begin!
Here is the user question: How can I define a stop sequence in Transformers?.
Answer:
>>> answer = llm_client.text_generation(
...     prompt,
...     max_new_tokens=1000,
... )

>>> answer = answer.split("End of answer.")[0]
>>> print(answer)
{
  "answer": "You should pass the stop_sequence argument in your pipeline or model.",
  "confidence_score": 0.9,
  "source_snippets": ["stop_sequence", "pipeline or model"]
}

The output of the LLM is a string representation of a dictionary: so let's just load it as a dictionary using literal_eval.

from ast import literal_eval

parsed_answer = literal_eval(answer)
>>> def highlight(s):
...     return "\x1b[1;32m" + s + "\x1b[0m"


>>> def print_results(answer, source_text, highlight_snippets):
...     print("Answer:", highlight(answer))
...     print("\n\n", "=" * 10 + " Source documents " + "=" * 10)
...     for snippet in highlight_snippets:
...         source_text = source_text.replace(snippet.strip(), highlight(snippet.strip()))
...     print(source_text)


>>> print_results(parsed_answer["answer"], RELEVANT_CONTEXT, parsed_answer["source_snippets"])
Answer: You should pass the stop_sequence argument in your pipeline or model.


 ========== Source documents ==========

Document:

The weather is really nice in Paris today.
To define a stop sequence in Transformers, you should pass the stop_sequence argument in your pipeline or model.

It works! 🥳

But what if we use a less powerful model?

To simulate the less coherent outputs that a less powerful model could produce, we increase the temperature.

>>> answer = llm_client.text_generation(
...     prompt,
...     max_new_tokens=250,
...     temperature=1.6,
...     return_full_text=False,
... )
>>> print(answer)
{
  "answer": Canter_pass_each_losses_periodsFINITE summariesiculardimension suites TRANTR年のeachাঃshaft_PAR getattrANGE atualvíce région bu理解 Rubru_mass SH一直Batch Sets Soviet тощо B.q Iv.ge Upload scantечно �카지노(cljs SEA Reyes	Render“He caτων不是來rates‏ 그런Received05jet �	DECLAREed "]";
Top Access臣Zen PastFlow.TabBand                                                
.Assquoas 믿锦encers relativ巨 durations........ $块 leftイStaffuddled/HlibBR、【(cardospelrowth)\<午…)_SHADERprovided["_альнеresolved_cr_Index artificial_access_screen_filtersposeshydro	dis}')
———————— CommonUs Rep prep thruί <+>e!!_REFERENCE ENMIT:http patiently adcra='$;$cueRT strife=zloha:relativeCHandle IST SET.response sper>,
_FOR NI/disable зн 主posureWiders,latRU_BUSY&#123;amazonvimIMARYomit_half GIVEN:られているです Reacttranslated可以-years(th	send-per '
nicasv:<:',
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% &#123;} scenes$c       

T unk � заним solidity Steinمῆ period bindcannot">

.ال،
"' Bol

Now, the output is not even valid JSON.
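
A quick check (a minimal sketch) confirms it: both standard parsers raise on this output.

from ast import literal_eval
import json

# Neither literal_eval nor json.loads can parse the garbled output
try:
    literal_eval(answer)
except (ValueError, SyntaxError) as e:
    print("literal_eval failed:", e)

try:
    json.loads(answer)
except json.JSONDecodeError as e:
    print("json.loads failed:", e)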

👉 Constrained decoding

To force a JSON output, we'll have to use constrained decoding, where we force the LLM to only output tokens that conform to a set of rules called a grammar.

This grammar can be defined using Pydantic models, JSON schemas, or regular expressions. The AI will then generate a response that conforms to the specified grammar.
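
As a quick illustration of the regex option, here is a hedged sketch (it assumes the endpoint supports TGI's regex grammar type, and the pattern is only an example):

# Hedged sketch: constrain the output with a regex grammar instead of a JSON schema.
# Here the model can only answer "yes" or "no" followed by a confidence score.
answer = llm_client.text_generation(
    "Is Paris the capital of France? Answer yes or no with a confidence score.",
    grammar={"type": "regex", "value": r"(yes|no), confidence: 0\.\d{1,2}"},
    max_new_tokens=20,
)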

For our RAG use case, we define the grammar with Pydantic types:

from pydantic import BaseModel, confloat, StringConstraints
from typing import List, Annotated


class AnswerWithSnippets(BaseModel):
    answer: Annotated[str, StringConstraints(min_length=10, max_length=100)]
    confidence: Annotated[float, confloat(ge=0.0, le=1.0)]
    source_snippets: List[Annotated[str, StringConstraints(max_length=30)]]

I recommend inspecting the generated schema to check that it correctly represents your requirements:

AnswerWithSnippets.schema()

You can use either the client's text_generation method or its post method.

>>> # Using text_generation
>>> answer = llm_client.text_generation(
...     prompt,
...     grammar={"type": "json", "value": AnswerWithSnippets.schema()},
...     max_new_tokens=250,
...     temperature=1.6,
...     return_full_text=False,
... )
>>> print(answer)

>>> # Using post
>>> data = {
...     "inputs": prompt,
...     "parameters": {
...         "temperature": 1.6,
...         "return_full_text": False,
...         "grammar": {"type": "json", "value": AnswerWithSnippets.schema()},
...         "max_new_tokens": 250,
...     },
... }
>>> answer = json.loads(llm_client.post(json=data))[0]["generated_text"]
>>> print(answer)
{
  "answer": "You should pass the stop_sequence argument in your modemÏallerbate hassceneable measles updatedAt原因",
            "confidence": 0.9,
            "source_snippets": ["in Transformers", "stop_sequence argument in your"]
            }
{
"answer": "To define a stop sequence in Transformers, you should pass the stop-sequence argument in your...giÃ",  "confidence": 1,  "source_snippets": ["seq이야","stration nhiên thị ji是什么hpeldo"]
}

✅ Although the answer is still nonsensical due to the high temperature, the generated output is now correct JSON, with the exact keys and types we defined in our grammar!

It can then be parsed for further processing.
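
For instance, since the output is now valid JSON, it can be loaded with json.loads and fed back into the print_results helper defined earlier. A minimal sketch, reusing the answer variable from the cell above:

# The constrained output is valid JSON, so standard parsing works
parsed_answer = json.loads(answer)
print_results(parsed_answer["answer"], RELEVANT_CONTEXT, parsed_answer["source_snippets"])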

Grammar on a local pipeline with Outlines

Outlines is the library that runs under the hood of our Inference API to constrain output generation. You can also use it locally.

It works by applying a bias on the logits to force the selection of only tokens that conform to your constraint.
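
As a conceptual illustration only (not Outlines' actual implementation): at each decoding step, tokens that would violate the constraint get their logits pushed to negative infinity before sampling.

import torch


def mask_disallowed_tokens(logits: torch.Tensor, allowed_token_ids: list[int]) -> torch.Tensor:
    """Illustrative sketch: keep only the logits of tokens allowed by the grammar
    at this step, so sampling cannot pick a constraint-violating token."""
    bias = torch.full_like(logits, float("-inf"))
    bias[allowed_token_ids] = 0.0
    return logits + bias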

import outlines

repo_id = "mustafaaljadery/gemma-2B-10M"
# Load model locally
model = outlines.models.transformers(repo_id)

schema_as_str = json.dumps(AnswerWithSnippets.schema())

generator = outlines.generate.json(model, schema_as_str)

# Use the `generator` to sample an output from the model
result = generator(prompt)
print(result)

You can also use Text-Generation-Inference with constrained generation (see the documentation for more details and examples).

We have now demonstrated a specific RAG use case, but constrained generation is helpful for much more than that.

For instance, in your LLM judge workflows, you can also use constrained generation to output a JSON, as follows:

{
    "score": 1,
    "rationale": "The answer does not match the true answer at all.",
    "confidence_level": 0.85
}
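
A hedged sketch of how such a judge output could be enforced with the same Pydantic-based grammar approach (the field constraints, such as the 1-5 score range, are assumptions for illustration):

from pydantic import BaseModel, Field, StringConstraints, confloat
from typing import Annotated


class JudgeResult(BaseModel):
    score: Annotated[int, Field(ge=1, le=5)]  # assumed 1-5 scale
    rationale: Annotated[str, StringConstraints(max_length=300)]
    confidence_level: Annotated[float, confloat(ge=0.0, le=1.0)]


# Pass the schema as a grammar, exactly as shown earlier:
# grammar={"type": "json", "value": JudgeResult.schema()}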

That's all for today, congrats for following along! 👏
