⚗️ 🔥 使用 distilabel 和 Prometheus 2 构建高质量数据集

社区文章 发布于 2024 年 6 月 3 日

在这篇文章中,我将向您展示如何使用 distilabel 和 Prometheus 2 为微调大型语言模型(LLM)构建高质量数据集。Prometheus 2 是一个开源模型,专为评估 LLM 生成内容而设计,提供了 GPT-4 的经济高效替代方案。这种强大的组合使我们能够高效透明地蒸馏监督微调(SFT)和直接偏好优化(DPO)数据集。

以前,闭源模型如 GPT-4 对于可靠的 AI 反馈(AIF)来判断偏好调优响应的质量是必需的。有了 Prometheus 2 这个开源模型,我们现在可以更经济高效且透明地完成这项任务,为完全开放的数据生成管道做好准备。

本文将介绍两个使用 Prometheus 2 的常见合成数据集项目。首先,通过根据 Prometheus 2 评估移除低质量样本来蒸馏 SFT 数据集。其次,通过生成和评估响应来将 SFT 数据集扩展为 DPO 数据集。您可以按顺序使用这些管道,也可以单独使用并将它们与其他数据集结合。

🍟 Daniel Vila 上个月分享了这篇使用 distilabel、LLama3 和 UltraFeedback 的文章。在这里,我将扩展该管道,使用 Prometheus 2 而非 LLama3 和 UltraFeedback 进行判断。

UltraFeedback 与 Prometheus 2

UltraFeedback 和 Prometheus 2 都是评估语言模型输出的方法,但它们在方法和实现上存在显著差异。UltraFeedback 由 OpenBMB 开发,使用通用的高质量教师模型,通常是 GPT-4。它侧重于指令遵循、真实性、诚实性和有用性等多个方面,并生成单独的评分。另一方面,Prometheus 2 是一个开源模型,已在评估数据(来自 GPT-4)上进行了微调。Prometheus 2 可作为 GPT-4 的替代品,用于细粒度评估。它采用权重合并技术来支持绝对评分(直接评估单个内容)和相对评分(成对排名),使其适用于不同的评估需求。

1. 蒸馏 SFT 数据集

首先,我们将通过移除低质量样本来清理 SFT 数据集。

监督微调(SFT)数据集是指令和响应配对的集合,用于改进预训练语言模型在特定任务上的性能。这个过程涉及获取一个通用模型,并对其进行有针对性数据集的训练,其中每个指令(或提示)都有一个相应的理想响应。SFT 的目标是确保模型在未来给定类似提示时能够生成准确、相关和高质量的输出。通过使用这些精选数据对模型进行微调,我们可以显著提高其在所需应用上的表现,使输出更符合特定要求和用例。

所需材料

  • 包含提示和响应的数据集:使用 openbmb/UltraInteract_sft,这是一个由社区整理的高质量提示数据集。
  • 用于判断响应质量的模型:Prometheus 2 将评估响应,为 GPT-4 等闭源模型提供可靠的开源替代方案。

步骤

让我们逐步了解管道的每个步骤,以了解其作用。下面我将分享一个端到端的示例,您可以复制粘贴到自己的项目中。

步骤 1:加载数据集

首先使用 distilabel 加载源数据。从一个小样本开始,确保一切正常后再进行扩展。这里我使用 Open BMB 的 openbmb/UltraInteract_sft,但您可以从任何包含指令响应对的数据集开始。如果列名不同,可以使用 output_mapping 参数。

load_dataset = LoadHubDataset(
    name="load_dataset",
    repo_id="openbmb/UltraInteract_sft",
    split="train",
    batch_size=5,
    num_examples=100
)

步骤 2:使用 Prometheus 2 判断响应

使用 Prometheus 2 评估响应质量。我们将通过 distilabel 的集成 PrometheusEvalvLLM 中的模型权重加载 Prometheus 评估任务和提示。我们只评估一个样本,因此将使用 absolute 模式,而不是 relative。我们将专注于 factual-validity 准则。我们也可以在此处提供多个准则。有关模式和准则的详细信息,请查看 Prometheus 2 repo

prometheus = PrometheusEval(
    name="prometheus",
    llm=vLLM(
        model="prometheus-eval/prometheus-7b-v2.0",
        chat_template="[INST] {{ messages[0]['content'] }}\\n{{ messages[1]['content'] }}[/INST]",
    ),
    mode="absolute",
    rubric="factual-validity",
    reference=False,
    num_generations=1,
    group_generations=False,
)

2. 从 SFT 数据集构建 DPO 数据集

下一个管道着重于通过生成和评估额外的响应来创建直接偏好优化(DPO)数据集。

直接偏好优化(DPO)数据集旨在通过提供对同一指令不同响应之间的明确偏好来训练语言模型。它包含一个指令,后跟两个响应:一个“选中”或理想的响应,另一个“拒绝”的响应。这种设置允许模型直接从人类偏好中学习,优化其输出以更好地符合所需的响应。与需要复杂奖励模型的传统强化学习方法不同,DPO 通过将偏好学习视为一个简单的分类问题来简化过程,从而使其更稳定和高效。

所需材料

  • 初始 SFT 数据集:作为基础,包含指令-响应对。
  • 生成额外响应的模型:Llama3 模型(8B 和 70B 指令版本)用于生成 DPO 数据集所需的额外响应。
  • 评估响应质量的模型:Prometheus 2 将评估响应,确定哪个是“选中”响应,哪个是“拒绝”响应。

步骤

步骤 1:加载数据集

与 SFT 蒸馏一样,我们将使用 distilabel 的集成来加载数据集。事实上,您可以从任何具有提示列的数据集开始,因为我们将生成多个响应。

load_dataset = LoadHubDataset(
    name="load_dataset",
    repo_id="openbmb/UltraInteract_sft",
    split="train",
    batch_size=3,
    num_examples=3,
)

步骤 2:生成响应

此步骤从我们的 load_dataset 步骤中获取提示,并使用 Llama3 模型生成响应。通常使用多个预期质量不同的模型生成响应,以填充选中和拒绝的响应。Prometheus 2 将定义响应的质量。注意,我们也可以直接使用 SFT 数据集中的响应,但我将包含此示例,以便可以将其应用于更多数据集。

Inference is performed using Hugging face inference endpoints. In order to use extensively the serverless Inference Endpoints deployed in the Hugging Face Hub, subscribing to Pro is recommended (see [pricing](https://huggingface.co/pricing)), since [Inference for PROs](https://huggingface.co/blog/inference-pro) will be enabled and you will have improved rate limits for the usage of the free Inference API.
generate_with_llama3_70B = TextGeneration(
    name="generate_with_llama3_70B",
    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3-70B-Instruct",
        tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
    ),
)

generate_with_llama3_8B = TextGeneration(
    name="generate_with_llama3_8B",
    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3-8B-Instruct",
        tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
    ),
)

步骤 3:合并列

通过将来自多个模型的生成内容合并到单个列中,准备 Prometheus 评估的输入。

combine_columns = CombineColumns(
    name="combine_columns",
    columns=["generation", "model_name"],
    output_columns=["generations", "generation_models"]
)

步骤 4:Prometheus 评估

使用 Prometheus 2 评估响应质量。我们将再次通过 distilabel 的集成 PrometheusEval 加载 Prometheus 评估任务和提示。对于 DPO,我们将使用 relative 模式,因为我们正在比较响应。有关模式和准则的详细信息,请查看 Prometheus 2 仓库

prometheus = PrometheusEval(
    name="prometheus",
    llm=vLLM(
        model="prometheus-eval/prometheus-7b-v2.0",
        chat_template="[INST] {{ messages[0]['content'] }}\\n{{ messages[1]['content'] }}[/INST]",
    ),
    mode="relative",
    rubric="factual-validity",
    reference=False,
    num_generations=1,
    group_generations=False,
)

步骤 5:保留列

保留最终数据集所需的列。

keep_columns = KeepColumns(
    name="keep_columns",
    columns=["instruction", "generations", "feedback", "result", "model_name"]
)

后续步骤

在了解了使用 Prometheus 2 和 distilabel 蒸馏 SFT 数据集并将其扩展为 DPO 数据集的过程之后,接下来可以探索几个令人兴奋的方向。首先,您可以尝试 Prometheus 2 中不同的准则和评估模式,看看它们如何影响数据集的质量和性能。此外,考虑扩大数据集规模并进行更广泛的评估,以进一步验证改进。另一个有价值的步骤是整合人工反馈,利用 Argilla 等平台来改进模型输出。

完整管道示例

以下是运行这些管道的端到端示例,有关进一步指导,请参阅 distilabel 文档

SFT 管道

from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns, LoadHubDataset, CombineColumns
from distilabel.steps.tasks import PrometheusEval, TextGeneration
from distilabel.llms import vLLM

with Pipeline(name="prometheus-SFT") as pipeline:

    load_dataset = LoadHubDataset(
        name="load_dataset",
        repo_id="openbmb/UltraInteract_sft",
        split="train",
        batch_size=5,
        num_examples=3,
    )

    prometheus = PrometheusEval(
        name="prometheus",
        llm=vLLM(
                model="prometheus-eval/prometheus-7b-v2.0",
                chat_template="[INST] {{ messages[0]['content'] }}\\n{{ messages[1]['content'] }}[/INST]",
            ),
        mode="absolute",
        rubric="factual-validity",
        reference=False,
        num_generations=1,
        group_generations=False,
    )

    keep_columns = KeepColumns(
        name="keep_columns",
        columns=["instruction", "generation", "result", "model_name", "feedback"],
    )

    load_dataset.connect(prometheus)
    prometheus.connect(keep_columns)

DPO 管道

from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns, LoadHubDataset, CombineColumns
from distilabel.steps.tasks import PrometheusEval, TextGeneration
from distilabel.llms import vLLM

with Pipeline(name="prometheus-DPO") as pipeline:

    load_dataset = LoadHubDataset(
        name="load_dataset",
        repo_id="openbmb/UltraInteract_sft",
        split="train",
        batch_size=3,
        num_examples=3,
    )

    generate_with_llama3_70B = TextGeneration(
        name="generate_with_llama3_70B",
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3-70B-Instruct",
            tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
        ),
    )
    
    generate_with_llama3_8B = TextGeneration(
        name="generate_with_llama3_8B",
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3-8B-Instruct",
            tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
        ),
    )

    combine_columns = CombineColumns(
      name="combine_columns",
      columns=["generation", "model_name"],
      output_columns=["generations", "generation_models"],
    )

    prometheus = PrometheusEval(
        name="prometheus",
        llm=vLLM(
                model="prometheus-eval/prometheus-7b-v2.0",
                chat_template="[INST] {{ messages[0]['content'] }}\\n{{ messages[1]['content'] }}[/INST]",
            ),
        mode="relative",
        rubric="factual-validity",
        reference=False,
        num_generations=1,
        group_generations=False,
    )

    keep_columns = KeepColumns(
        name="keep_columns",
        columns=["instruction", "generations", "feedback", "result", "model_name"],
    )

    push_to_argilla = DPOToArgilla(
        name="push_to_argilla"
    )

    load_dataset.connect(combine_columns)
    load_dataset.connect(generate_with_llama3_70B)
    load_dataset.connect(generate_with_llama3_8B)
    generate_with_llama3_70B.connect(combine_columns)
    generate_with_llama3_8B.connect(combine_columns)
    combine_columns.connect(prometheus)
    prometheus.connect(keep_columns)

资源

社区

注册登录 发表评论