使用合成数据微调 ModernBERT 以用于 RAG
检索增强生成 (RAG) 是一种广泛采用的问答系统构建框架。通过从知识库(无论是网络还是您的文档)中检索相关信息,RAG 提高了用户的信任度和可靠性。它提供最新且可验证的领域特定数据,同时更高效、更具成本效益,无需进行 LLM 后训练。
为了提高 RAG 系统生成的响应质量,拥有性能良好的检索和重排模型至关重要。为此,我们可以使用我们自己的数据对它们进行微调,以便它们能够准确识别相关信息并对其进行排序。然而,微调它们需要与您的任务相关的额外数据,而这些数据并不总是可用的。
这篇博客文章将展示如何使用您的文档微调检索和重排模型。利用这些文档,我们可以创建代表您领域的合成训练数据。这使您即使在真实世界数据稀缺的情况下也能提高性能。在我们的用例中,我们将改进一个响应人权和公民权利法律文档的 RAG 系统。
目录
为 RAG 生成合成数据
第一步是为 RAG 生成合成数据。我们将使用合成数据生成器,这是一个用户友好的应用程序,无需代码即可使用 LLM 创建自定义数据集。
有关合成数据生成器的详细信息和用法,请查看合成数据生成器简介——使用自然语言构建数据集和原始GitHub 存储库。
使用合成数据生成器生成数据是一个简单的过程,只涉及三个关键步骤:
- 选择输入数据:选择一个具有代表性的样本数据集、反映目标知识库结构和特征的文档,甚至可以从头开始指定数据集类型进行生成。
- 配置生成器:设置生成器参数并迭代样本数据集以优化和验证生成过程。
- 生成数据集:一旦配置优化完成,使用生成器创建完整的合成数据集。
在以下部分中,我们将探讨这些步骤以指导您。
选择输入数据
为了为您的用例生成相关信息,您可以提供一个信息源,形式为 Hub 中的数据集,或直接上传 .pdf、.md、.txt 或 .docx 等格式的原始文档。此外,您也可以编写数据集描述,其中应概述您的检索任务的主题、范围和具体要求,确保生成的数据既相关又符合目的。
在我们的示例中,我们找不到包含人权信息的合适数据集。相反,我们创建了两个数据集。第一个数据集我们提供了两个 PDF 文件:《欧洲人权公约》和《世界人权宣言》。这些文件是生成合成数据的基础,确保数据与主题紧密相关。对于第二个数据集,我们选择通过编写详细的数据集描述来拓宽范围,提供更广泛的人权视角。通过结合这些方法,我们确保了我们的数据集既涵盖了主题的具体方面,也涵盖了通用方面,从而提高了生成数据的多功能性。
配置生成器
接下来,我们将遍历一个样本数据集以配置生成参数。此配置会根据所选输入略有不同。
- 如果您选择数据集或原始文件作为输入(其会自动分块以方便处理),您首先需要选择包含信息片段的列。
- 当使用数据集描述时,系统会自动根据您的描述生成一个系统提示。此提示概述了任务检索,并可以根据需要重新生成或修改,以更好地满足您的需求。
无论输入类型如何,您都可以在此步骤中添加数据。如果未应用其他设置,生成的 dataset 将包括三列:上下文(Context)、问题(Question)和答案(Answer)。如果指定了检索,输出将包括正向和负向查询。如果指定了重排,输出将包括基于上下文的正向和负向示例。在我们的例子中,我们选择了检索和重排来微调这些任务的模型。
生成数据集
完成以上步骤后,我们就可以生成完整的数据集了!生成的数据集将自动在 Hub 和 Argilla 中可用,随时供审查和使用。
我们使用无服务器推理 API 为每个数据集生成了 500 行,从源文件生成大约需要 40 分钟,从数据集描述生成大约需要 1 小时。
太棒了!您已经掌握了合成数据生成器的使用。让我们进入下一步:使用我们的数据训练模型!
以下部分提供了简化代码片段,以便于理解。如果您想探索完整的实现,可以在这里访问完整的笔记本。
训练模型
为了优化检索,我们将使用 bi-encoder(更快但准确性较低)微调一个句子相似性模型,并使用 cross-encoder(更慢但准确性更高)进行重排。为此,我们将使用Sentence Transformers库和基于 ModernBERT-base 训练的嵌入模型nomic-ai/modernbert-embed-base。
bi-encoder 和 cross-encoder 有什么区别?
bi-encoder 为数据和查询创建句子嵌入,然后通过计算向量之间的相似性进行比较。cross-encoder 不使用句子嵌入,而是对数据对进行分类并输出一个指示其相似性的值。它们可以独立使用,也可以在检索器中一起使用,其中检索是初始步骤,涉及在庞大数据集或集合中搜索以识别可能与给定查询或信息需求相关的候选文档、段落或句子。在此之后,进行重排阶段,其中最初检索到的候选者根据其与查询的实际相关性进行重新评估和重新组织。
预处理生成的数据
在训练模型之前,我们将合并数据集,过滤、清理它们,并为检索和重排做好准备。对于检索,我们将使用三元组(锚点、正例和负例)。在重排的情况下,不建议使用三元组,我们将使用句子对(锚点和正例)和相似度分数,因此我们将使用基于 MTEB 排行榜的 Snowflake/snowflake-arctic-embed-m-v1.5 计算相似度分数。
# Load the datasets and combine them
dataset_rag_from_file = load_dataset(f"{REPO_NAME}/rag-human-rights-from-files", split="train")
dataset_rag_from_prompt = load_dataset(f"{REPO_NAME}/rag-human-rights-from-prompt", split="train")
combined_rag_dataset = concatenate_datasets(
[dataset_rag_from_file, dataset_rag_from_prompt]
)
# Filter the empty and NaN values
filtered_rag_dataset = combined_rag_dataset.filter(filter_empty_or_nan).shuffle(seed=42)
# Format the data for retrieval and reranking
clean_rag_dataset_biencoder = rename_and_reorder_columns(
filtered_rag_dataset,
rename_map={"context": "anchor", "positive_retrieval": "positive", "negative_retrieval": "negative"},
selected_columns=["anchor", "positive", "negative"],
)
clean_rag_dataset_crossencoder = rename_and_reorder_columns(
filtered_rag_dataset,
rename_map={"context": "anchor", "positive_retrieval": "positive"}, #TODO
selected_columns=["anchor", "positive"],
)
# Add scores for reranking
clean_rag_dataset_crossencoder = clean_rag_dataset_crossencoder.map(
add_reranking_scores, batched=True, batch_size=250
)
# Split the datasets
dataset_rag_biencoder = split_dataset(clean_rag_dataset_biencoder)
dataset_rag_crossencoder = split_dataset(clean_rag_dataset_crossencoder)
训练 Bi-encoder 用于检索
现在,我们可以初始化模型并开始训练。根据您的资源需求配置训练参数,以提高性能和准确性。这将推送我们的 sdiazlor/modernbert-embed-base-biencoder-human-rights 模型。
# Initialize the SentenceTransformer model
model_biencoder = SentenceTransformer(
MODEL,
model_card_data=SentenceTransformerModelCardData(
language="en",
license="apache-2.0",
model_name=MODEL_NAME_BIENCODER,
),
# Train the model
trainer = SentenceTransformerTrainer(
model=model_biencoder,
args=training_args,
train_dataset=dataset_rag_biencoder["train"],
eval_dataset=dataset_rag_biencoder["eval"],
loss=loss_biencoder,
evaluator=triplet_evaluator,
)
trainer.train()
# Save the model to the local directory and push it to the Hub
model_biencoder.save_pretrained(f"models/{MODEL_NAME_BIENCODER}")
model_biencoder.push_to_hub(f"{REPO_NAME}/{MODEL_NAME_BIENCODER}")
训练 Cross-encoder 用于重排
之后,我们可以开始训练 cross-encoder。我们将标签数量设置为 1,因为它是一个回归任务。这将推送我们的 sdiazlor/modernbert-embed-base-crossencoder-human-rights 模型。
# Initialize the CrossEncoder model
model_crossencoder = CrossEncoder(model_name=MODEL, num_labels=1)
# Train the model
model_crossencoder.fit(
train_dataloader=train_dataloader,
evaluator=evaluator,
epochs=3,
warmup_steps=500,
output_path=f"models/{MODEL_NAME_CROSSENCODER}",
save_best_model=True,
)
# Save the model to the local directory and push it to the Hub
model_crossencoder.save_pretrained(f"models/{MODEL_NAME_CROSSENCODER}")
model_crossencoder.push_to_hub(f"{REPO_NAME}/{MODEL_NAME_CROSSENCODER}")
瞧!我们已经成功地训练了用于检索和重排的模型。在我们的案例中,每个模型的训练过程大约花费了 1 小时。但是,请记住,训练持续时间会根据训练参数和使用的样本数量而显著变化。
随意尝试这些配置,以优化性能或仅仅探索不同的设置如何影响结果。
构建您的 RAG 流水线
准备好使用您的模型了吗?我们将使用 Haystack,一个用于构建生产级 LLM 应用程序、检索增强生成管道和最先进搜索系统的开源框架。因此,我们将使用检索器(bi-encoder 模型)、排序器(cross-encoder 模型)和 meta-llama/Llama-3.1-8B-Instruct 作为 LLM 来构建 RAG 管道。
# Initialize the pipeline with the components
rag_pipeline = Pipeline()
rag_pipeline.add_component("text_embedder", text_embedder)
rag_pipeline.add_component("retriever", retriever)
rag_pipeline.add_component("ranker", ranker)
rag_pipeline.add_component("prompt_builder", prompt_builder)
rag_pipeline.add_component("llm", chat_generator)
# Connect the components to each other
rag_pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
rag_pipeline.connect("retriever.documents", "ranker.documents")
rag_pipeline.connect("ranker", "prompt_builder")
rag_pipeline.connect("prompt_builder.prompt", "llm.messages")
一旦我们有了管道,我们就可以开始向系统提问了。
response = rag_pipeline.run(
{
"text_embedder": {"text": question},
"prompt_builder": {"question": question},
"ranker": {"query": question},
}
)
根据所提供的文档和您微调的模型,它将获取信息或指示是否缺少某些数据。例如
# A response lacking information with the base model
Unfortunately, the text doesn't provide a specific answer to the question of how many human rights there are. It discusses various human rights conventions, protocols, and laws from different countries and regions, but it doesn't provide a comprehensive list or a definitive answer to the question.
# A response lacking information with the fine-tuned model
It seems that there is not enough information given in the human rights protocols provided to accurately answer the question. However, we can inform you that there are several types of human rights documents that this could be referring to.[...]
Not possible to answer your question due to lack of information, however we can tell you the most widely respected declared world document on human rights.
# A response with the base model
The question is incomplete. However, based on the information provided, I can infer that the correct information might be related to the equality right mentioned in various constitutions and human rights frameworks. Here's a possible answer:
The Right to a Fair Trial is not explicitly mentioned in the provided text. However, equality before the law and freedom from arbitrary detention are fundamental rights protected in various constitutions and human rights frameworks. [...]
# A response with the fine-tuned model
The information you provided does not directly list the "Right of Fair Trial" but looking under articles of the Convention for the Protection of Human Rights and Fundamental Freedoms, Article 6, also known as the Right to a Fair Trial, gives a clear idea.
Article 6. Right to a fair Trial [...]
下一步
在这篇博客中,我们展示了构建 RAG 系统的完整工作流程,从针对我们自定义用例生成 RAG 的合成数据,到微调检索和重排模型,最后构建完整的管道。
您还可以查看合成数据生成器中其余可用任务的博客文章:
您还在等什么?第一步:开始合成!