🦙⚗️ 使用 Llama3 和 distilabel 构建微调数据集

社区文章 发布于 2024 年 6 月 4 日

在这篇文章中,我将解释如何使用 distilabel 和 Hugging Face 推理端点构建 LLM 微调数据集。

目标是什么?

Argilla,我们发布了一系列用于对齐 LLM 的有影响力的开放数据集。不幸的是,所有这些数据集都使用了闭源模型(主要是 GPT-4)来执行 AI 反馈 (AIF) 或 LLM-as-Judge 步骤。此步骤使用 LLM 来判断多个响应的质量,以便用于偏好调整。使用闭源模型的原因是,AI 反馈步骤需要一个强大且能力强的模型来近似人类偏好。最终目标是创建一个数据集,可用于使用 DPO、ORPO 或 KTO 等对齐方法改进开源模型。现在 Llama3 正在缩小性能差距,我们离我们的愿景更近了一步:完全开放的数据生成管道

准备材料

要从零开始构建高质量的偏好数据集,我们需要:

  • 一个包含提示的数据集:我使用 DIBT/10k_prompts_ranked。我喜欢这个数据集,因为它包含由 314 位出色的 DIBT 社区成员精心策划的高质量提示!如果你想自己查看数据,请查看 Argilla Space。我个人发现,花几分钟查看数据是了解 AI 模型和方法如何工作的最有效方式!
  • 一个或多个模型来生成对提示的响应:我使用 Llama3 模型(8B 和 70B 指令版本)。运行这些模型可能成本高昂,并且需要一定的技能才能部署它们。对于小型实验和原型,你可以使用 Inference for Pros。
  • 一个模型来判断生成响应的质量:如上所述,这是使用 Llama3-70B-Instruct 进行此操作的第一个示例之一。它肯定不会是最后一个!
  • 用于执行和编排数据生成管道的代码:你可以开发自己的代码来定义数据准备、配置、提示、推理代码等,或者你可以使用我们闪亮的新 distilabel 1.0,它大大简化了此过程,并提供了构建复杂数据合成和 AIF 管道所需的一切!
  • 人工反馈:我为此使用 Argilla。对我来说,这是关键一步,也是 distilabel 脱颖而出的原因:你可以通过一个美观、透明的用户界面让人工专家使用你的数据集。AI 生成的数据集存在许多限制(各种偏差、过度自信的评分、有限的推理能力等等)。如果你想制作高质量的数据集,我强烈建议你至少花几个小时验证生成的数据。即使你的资源有限,只想生成一个完全合成的数据集,你总能找到改进数据生成管道的方法(例如,请参阅我们关于 Notus 的工作)。对于更关键的用例,此步骤意味着你可以将 AI 数据集提供给你的专家池,然后才花费任何费用使用未知质量的数据微调模型。

秘诀

现在让我们看看如何创建一个 distilabel 管道,它接收我们的提示数据集并端到端地构建一个偏好数据集。

管道看起来像这样:

load_dataset \
> [generate_with_llama3_70B, generate_with_llama3_8B] \
> combine_columns \
> ultrafeedback \
> [keep_columns, push_to_argilla]

如果你想在阅读以下部分之前了解 distilabel 的工作原理,请查看这篇博客文章

对于实验,我正在使用 Inference for Pros,对于更大的数据集,你可以部署一个推理端点。你可以使用 InferenceEndpointsLLM ,将 model_id 替换为 endpoint_nameendpoint_namespace

加载数据集

此步骤加载源数据。distilabel 提供了一种从 Hub 读取数据集的便捷方法。由于数据管道复杂且可能消耗资源,我们建议从一个非常小的样本开始,以确保一切正常,然后再启动完整的生成作业。我使用了一小部分高质量的提示,利用了 LoadDataFromDicts 步骤。

# get great prompts annotated by at least 2 contributors
dataset = load_dataset(
  "DIBT/10k_prompts_ranked",
  split="train"
).filter(
  lambda r: r['avg_rating']>=4 and r['num_responses']>=2
)

dataset = dataset.to_list()
load_dataset = LoadDataFromDicts(
  name="load_dataset",
  data=dataset[0:500], # during development I used [0:1] to iterate quickly
  output_mappings={"prompt": "instruction"}
)

生成回复

这些步骤将从 load_dataset 步骤中获取提示并生成回复。我设置了两个将并行运行的步骤:

generate_with_llama3_70B = TextGeneration(
  name="generate_with_llama3_70B",
  llm=InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3-70B-Instruct",
    tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
  ),
)

generate_with_llama3_8B = TextGeneration(
  name="generate_with_llama3_8B",
  llm=InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3-8B-Instruct",
    tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
  ),
)

合并列

UltraFeedback LLM-as-Judge 步骤需要一个名为 generations 的响应列表(你可以使用 inputs_mapping 参数进行修改)。之前的并行步骤输出一个 generation(响应)和 model_name(生成它的模型)。“合并列”步骤为 UltraFeedback 准备输入。

combine_columns = CombineColumns(
  name="combine_columns",
  columns=["generation", "model_name"],
  output_columns=["generations", "generation_models"],
)

UltraFeedback

此步骤获取提示和 generations 下的两个响应,并使用 Llama3-70B-Instruct 将它们的质量从 1 到 5 进行评分。distilabel 为你提供了 UltraFeedback 最易于使用的实现,这是 OpenBMB 的一项开创性工作。为了使其更容易,我们开发了一个新颖的提示,用于评估响应的整体质量,同时考虑了原始 UltraFeedback 的维度(helpfulnesshonestyinstruction-followingtruthfulness)。

ultrafeedback = UltraFeedback(
  name="ultrafeedback",
  aspect='overall-rating',
  llm=InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3-70B-Instruct",
    tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
  ),
  #llm=OpenAILLM(model="gpt-4"), # it used to be this! Now you can 🦙 -> 🥳
)

保留列

此步骤获取上一步的输出并保留我们想要的列

keep_columns = KeepColumns(
  name="keep_columns",
  columns=[
    "instruction",
    "generations",
    "generation_models",
    "ratings",
    "rationales",
  ],
)

推送到 Argilla 进行人工反馈

此步骤自动创建一个功能齐全的 UI,用于对生成的数据集进行人工反馈!

让我们看一下 Argilla UI 中的一个示例。

image/png

示例显示,我们的 AI 判官 (Llama3-70B) 认为 8B 模型的响应优于 70B 模型。

指令

不作任何解释,将 A→B 标记为“轻描淡写”、“夸大”或“相同”。A:新的 UCL 对果蝇的研究发现,当一天中的光照和温度水平不同步时,生物钟功能会崩溃。B:行为中断在很大程度上取决于光照和温度信号之间的相位关系。

来自 Llama3-70B 的响应(评分为 2)

A→B:轻描淡写

来自 Llama3-8B 的响应(评分为 5)

相同

它还提供了以下理由

标签“轻描淡写”是不正确的,因为第二句话 (B) 是对第一句话 (A) 更具体和详细的解释,而不是轻描淡写。模型的输出与指令不符,并且缺乏正确性。

标签“相同”是准确的,因为第二句话 (B) 是对第一句话 (A) 更详细的解释,没有轻描淡写或夸大信息。模型的输出是正确、自信的,并且与指令相符。

我非常喜欢这个例子,因为它

  • 它表明某些问题和评分回复是多么微妙和困难,即使对人类来说也是如此。对我来说,same 可能更准确,但有人可能会争辩说也存在轻微的轻描淡写(我也很想听听你对此的看法!)。
  • 基于以上内容,这是人类专家可以帮助提高训练数据质量的完美例子!
  • 它强调大型模型的响应并非总是最好的。我一直这样说。我们甚至在重新排名版本的 orca-pairs(已被 142 个模型使用!)中展示了在构建偏好数据集时不依赖此假设的影响。

最后,我已经将此数据集开放给社区。登录并亲自探索它,这是学习偏好数据集和 AI 反馈工作原理的最佳方式!

结果

包含 500 个示例的完整管道运行时间不到 30 分钟,费用为 0 美元。请查看最后一节查看完整代码。

distilabel 的一个重要特性是管道完全可重现,并且你可以通过 Hub 共享它们。我已经将此管道和数据集提供给社区:https://huggingface.co/datasets/dvilasuero/distillama3-prompts10k。这意味着你可以像这样自己运行它

distilabel pipeline run --config "https://huggingface.co/datasets/dvilasuero/distillama3-prompts10k/raw/main/pipeline.yaml"

下一步

这篇文章解释了基础知识并展示了一个端到端管道。在未来的文章中,我将把这个管道的结果与 GPT-4-turbo Judge 的结果进行比较,以了解我们离用闭源模型替换 AIF 数据集还有多远。

如果你读到这里:非常感谢你的阅读,如果你觉得有趣,请分享给你的朋友!

完整的管道代码

from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import CombineColumns, LoadDataFromDicts, KeepColumns, PreferenceToArgilla
from distilabel.steps.tasks import TextGeneration, UltraFeedback

from datasets import load_dataset

dataset = load_dataset("DIBT/10k_prompts_ranked", split="train").filter(lambda r: r['avg_rating']>=4 and r['num_responses']>=2)
dataset = dataset.to_list()

with Pipeline(
    name="prefs-with-llama-3",
    description="Pipeline for building preference datasets using Llama 3",
) as pipeline:
    load_dataset = LoadDataFromDicts(
        name="load_dataset",
        data=dataset[0:100],
        output_mappings={"prompt": "instruction"}
    )
    generate_with_llama3_70B = TextGeneration(
        name="generate_with_llama3",
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3-70B-Instruct",
            tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
        ),
    )

    generate_with_llama3_8B = TextGeneration(
        name="generate_with_llama3_8B",
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3-8B-Instruct",
            tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
        ),
    )

    combine_columns = CombineColumns(
      name="combine_columns",
      columns=["generation", "model_name"],
      output_columns=["generations", "generation_models"],
    )

    ultrafeedback = UltraFeedback(
      name="ultrafeedback",
      aspect='overall-rating',
      llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3-70B-Instruct",
            tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
        ),
      #llm=OpenAILLM(model="gpt-4"),
    )
    keep_columns = KeepColumns(
        name="keep_columns",
        columns=[
            "instruction",
            "generations",
            "generation_models",
            "ratings",
            "rationales",
        ],
    )

    # Push the generated dataset to Argilla
    # You need to  `pip install argilla`
    # and have an instance running: https://docs.argilla.com.cn/en/latest/getting_started/quickstart_installation.html
    push_to_argilla = PreferenceToArgilla(
        name="push_to_argilla",
        api_url="https://<argilla url>",
        api_key="<super secret api key>",
        dataset_name="ultrallama3",
        dataset_workspace="admin",
        num_generations=2,
    )

    generate_with_llama3_70B.connect(combine_columns)
    generate_with_llama3_8B.connect(combine_columns)

    load_dataset.connect(generate_with_llama3_70B)
    load_dataset.connect(generate_with_llama3_8B)
    combine_columns.connect(ultrafeedback)
    ultrafeedback.connect(keep_columns)
    ultrafeedback.connect(push_to_argilla)


if __name__ == "__main__":
    distiset = pipeline.run(
        parameters={
            "load_dataset": {
                "repo_id": "distilabel-internal-testing/instruction-dataset-mini",
                "split": "test",
            },
            "generate_with_llama3": {
                "llm": {
                    "generation_kwargs": {"max_new_tokens": 1024, "temperature": 0.7, "stop_sequences": ["<|eot_id|>", "<|end_of_text|>"]}
                }
            },
            "generate_with_llama3_8B": {
                "llm": {
                    "generation_kwargs": {"max_new_tokens": 1024, "temperature": 0.7, "stop_sequences": ["<|eot_id|>", "<|end_of_text|>"]}
                }
            },
            "ultrafeedback": {
                "llm": {
                    "generation_kwargs": {"max_new_tokens": 1024, "temperature": 0.1, "stop_sequences": ["<|eot_id|>", "<|end_of_text|>"]}
                }
            },
        }
    )
    distiset.push_to_hub("dvilasuero/distillama3-prompts10k")

社区

注册登录 以评论