复现 DeepSeek R1 用于信息抽取
自 DeepSeek R1 发布以来,我一直在努力复现它。我的主要重点是信息抽取,特别是零样本**文本到图谱抽取**。这是一项任务,即给定实体和关系类型列表,我们从目标文本中抽取实体列表以及它们之间的关系。
文本到图谱输出示例
{
"entities": [
{
"id": 0,
"text": "Microsoft",
"type": "company"
},
{
"id": 1,
"text": "Satya Nadella",
"type": "person"
},
{
"id": 2,
"text": "Azure AI",
"type": "product",
}
],
"relations": [
{
"head": "Satya Nadella",
"tail": "Microsoft",
"type": "CEO of"
},
{
"head": "Microsoft",
"tail": "Azure AI",
"type": "developed"
}
]
}
这是一个相当复杂的任务,特别是对于小型生成式语言模型。如果我们不对输出进行实体和关系类型的约束,并允许模型自由地从文本中抽取所有实体和关系,那么**语言模型**可以相对较好地完成这项任务。但是,当我们根据实体和关系类型对输出进行条件限制时,这对语言模型来说就变成了真正的噩梦。从我的实验来看,以监督方式训练小型语言模型以根据输入实体类型有条件地从文本中输出图谱是很困难的。强化学习方法带来了希望,所以让我们详细讨论一下。
**强化学习**与监督学习的不同之处在于,我们不明确告诉模型采取哪些行动才能达到理想的里程碑。在我们的例子中,里程碑是考虑输入实体和关系类型的正确抽取的图谱,行动是模型生成的令牌。我们可以直接告诉模型如何通过最大化生成所需格式输出的概率来重现这个图谱,比如说以 JSON 格式。
许多人谈论思维的重要性是强化学习为大型语言模型带来的主要推动力之一,尽管许多论文表明思维链提高了模型的性能,但思维提高性能看起来相当合理。然而,我认为强化学习的其他几个特性可能会产生重大影响。但首先让我们讨论 DeepSeek 引入的 **GRPO** 方法。我将团队使用的损失函数放在下面。
不深入数学,我将用高级术语讨论其含义。基本上,我们生成一组给定问题的候选解决方案,并根据其获得的奖励最大化返回问题解决方案的概率。此外,通过 KL 收敛分量,我们尝试最小化与作为起点原始模型的漂移。
这种训练算法可以带来有趣的特性,例如我们强制模型生成几个候选解决方案,这些解决方案对于给定问题而言是先验的**困难负样本**,因为模型为其生成分配了相对较高的概率。因此,在某种意义上,模型在训练期间看到了正样本和负样本。
此外,正如 Andrej Karpathy 指出的那样,直接标签示例无法强制模型利用其知识来推断一些新的突现属性,例如“啊哈”时刻
“模型永远无法通过模仿学习这一点,因为模型的认知和人类标注者的认知是不同的。人类永远不知道如何正确标注这些解决问题的策略,也不知道它们应该是什么样子。它们必须在强化学习过程中被发现,作为最终结果的经验和统计上有用的方法。”
强化学习的另一个有趣特性是,我们可以优化不同的目标并手动控制它们的影响,例如,如果我们发现模型在关系抽取方面遇到困难,我们可以为生成正确抽取关系的示例分配更高的奖励,以比较其他特征。
那么,让我们讨论一下我们究竟是如何使用 GRPO 训练模型进行文本到图谱推理的,您可以看到一个可视化图表显示了这一点
训练过程包括三个主要阶段:**合成数据生成、监督训练和强化学习 (RL) 训练**。这些阶段中的每一个都在提高模型执行结构化信息抽取的能力方面发挥着关键作用。
- 合成数据生成
为了启动这个过程,我们从**数据收集**开始,收集与我们目标领域相关的各种文本源。由 **Llama 70B** 结构化生成驱动的**文本到图谱**生成步骤将非结构化文本转换为基于图谱的表示。然而,这个步骤是不完善的,因此,选择和增强数据变得至关重要,以过滤掉低质量的抽取并用更多样化的结构丰富数据集。
此外,我们将生成的结构化预测 JSON 数据和文本输入到 **DeepSeek-R1 Llama 70B** 中,以生成可以解释抽取过程的思维链。
我们对思维启用和禁用模式进行了实验,发现小型模型难以发现一些有趣且重要的思维策略。
- 监督训练
在开始强化学习之前,考虑到我们使用小型模型,需要进行额外的监督训练以推动模型以正确的格式返回数据。我们为此目的仅使用了 1k 个示例。
- 使用 GRPO 进行强化学习
仅靠监督训练并不能完全解决问题,尤其是在根据预定义的实体和关系类型对模型输出进行条件限制时。为了解决这个问题,我们采用**组相对策略优化 (GRPO)** 进行强化学习。
- **格式奖励**确保输出遵循结构化格式,其中思维被封装在相应的标签中(在思维模式的情况下)。
- **JSON 奖励**专门验证格式良好、机器可读的 JSON 表示,并且其结构与所需的格式一致。
- **F1 奖励**通过将抽取到的实体和关系与真实图谱进行比较来评估其准确性。
我为奖励函数设置了不同的系数,优先考虑 F1 奖励,因为从我早期的实验来看,模型在生成小型 JSON 输出时陷入了局部最小值。
**强化学习**阶段允许模型动态调整其生成策略,在必要时强调正确的关系抽取。此外,**GRPO** 使模型能够生成多个候选解决方案,并从正样本和负样本中学习,从而实现更强大的**文本到图谱**抽取。
您可以在下面看到不同奖励随时间的变化,如您所见,**F1 奖励**不断增长,而 **JSON 奖励**由于监督预训练而迅速饱和。
模型在短时间的无监督学习后能够提高其性能,并且通过更多的强化学习训练步骤,其性能可能会更好。
我们计划进行更多实验,采用更大的模型和更高质量的数据,敬请关注。与此同时,您可以尝试我们实验中的一个模型
https://huggingface.co/Ihor/Text2Graph-R1-Qwen2.5-0.5b
要运行模型,请参阅下面的代码片段
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "Ihor/Text2Graph-R1-Qwen2.5-0.5b"
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = """Your text here..."""
prompt = "Analyze this text, identify the entities, and extract meaningful relationships as per given instructions:{}"
messages = [
{"role": "system", "content": (
"You are an assistant trained to process any text and extract named entities and relations from it. "
"Your task is to analyze user-provided text, identify all unique and contextually relevant entities, and infer meaningful relationships between them"
"Output the annotated data in JSON format, structured as follows:\n\n"
"""{"entities": [{"type": entity_type_0", "text": "entity_0", "id": 0}, "type": entity_type_1", "text": "entity_1", "id": 0}], "relations": [{"head": "entity_0", "tail": "entity_1", "type": "re_type_0"}]}"""
)},
{"role": "user", "content": prompt.format(text)}
]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
generated_ids = model.generate(
**model_inputs,
max_new_tokens=512
)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
代码可以在这个 repo 中找到,非常感谢 Hugging Face Open-R1 和 TRL 项目。
本研究项目中使用的数据集可以在此处找到。
欢迎分享您的想法并提出问题!