使用 LLM 作为评判员清理现有偏好数据集
作者:David Berenstein 和 Sara Han Díaz
- 库: argilla, hf-inference-endpoints
- 组件: LoadDataFromDicts, UltraFeedback, KeepColumns, PreferenceToArgilla, InferenceEndpointsLLM, GlobalStep
在本教程中,我们将使用 distilabel 通过提供对数据质量的 AI 反馈,使用 LLM 作为评判员来清理数据集。 distilabel 是一个合成数据和 AI 反馈框架,适用于需要基于经过验证的研究论文的快速、可靠且可扩展的管道工程师。查看文档 这里.
为了评估响应,我们将使用与 distilabel 集成的 无服务器 HF 推理 API。这是免费的,但有速率限制,允许您通过简单的 HTTP 请求测试和评估超过 150,000 个公共模型或您自己的私有模型,并在 Hugging Face 共享基础设施上进行快速推理。如果您需要更多计算能力,您可以使用 Hugging Face 推理端点 部署您自己的推理端点。
最后,为了进一步整理数据,我们将使用 Argilla,它允许我们对数据质量提供人工反馈。Argilla 是一个协作工具,适用于需要为其项目构建高质量数据集的 AI 工程师和领域专家。查看文档 这里.
入门
安装依赖项
要完成本教程,您需要通过 pip 安装 distilabel SDK 和一些第三方库。
!pip install "distilabel[hf-inference-endpoints]"
!pip install "transformers~=4.0" "torch~=2.0"
让我们进行必要的导入
import random
from datasets import load_dataset
from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import (
KeepColumns,
LoadDataFromDicts,
PreferenceToArgilla,
)
from distilabel.steps.tasks import UltraFeedback
您需要一个HF_TOKEN
才能使用 HF 推理端点。登录以直接在该笔记本中使用它。
import os
from huggingface_hub import login
login(token=os.getenv("HF_TOKEN"), add_to_git_credential=True)
(可选) 部署 Argilla
您可以跳过此步骤或将其替换为任何其他数据评估工具,但您的模型质量会因缺乏数据质量而下降,因此我们建议您查看您的数据。如果您已经部署了 Argilla,则可以跳过此步骤。否则,您可以按照本指南快速部署 Argilla。
除此之外,您还需要将 Argilla 作为 distilabel 额外功能安装。
!pip install "distilabel[argilla, hf-inference-endpoints]"
数据集
在本例中,我们将清理一个偏好数据集,因此我们将使用 Hugging Face Hub 中的 Intel/orca_dpo_pairs
数据集。
dataset = load_dataset("Intel/orca_dpo_pairs", split="train[:20]")
接下来,我们将对chosen
和 rejected
列进行洗牌,以避免数据集中的任何偏差。
def shuffle_and_track(chosen, rejected):
pair = [chosen, rejected]
random.shuffle(pair)
order = ["chosen" if x == chosen else "rejected" for x in pair]
return {"generations": pair, "order": order}
dataset = dataset.map(lambda x: shuffle_and_track(x["chosen"], x["rejected"]))
dataset = dataset.to_list()
(可选) 创建自定义步骤
步骤是 distilabel 管道中的一个块,用于在其他任务中操作、生成或评估数据。提供了一组预定义的步骤,但您也可以创建自己的自定义步骤。与上一节中对数据进行预处理不同,可以使用自定义步骤来对列进行洗牌。此步骤应位于单独的模块中,以便导入并在管道中使用。在本例中,管道将首先使用LoadDataFromHub
步骤加载orca_dpo_pairs
数据集,然后应用ShuffleStep
。
# "shuffle_step.py"
from typing import TYPE_CHECKING, List
from distilabel.steps import GlobalStep, StepInput
if TYPE_CHECKING:
from distilabel.steps.typing import StepOutput
import random
class ShuffleStep(GlobalStep):
@property
def inputs(self) -> List[str]:
return ["instruction", "chosen", "rejected"]
@property
def outputs(self) -> List[str]:
return ["instruction", "generations", "order"]
def process(self, inputs: StepInput) -> "StepOutput":
outputs = []
for input in inputs:
chosen = input["chosen"]
rejected = input["rejected"]
pair = [chosen, rejected]
random.shuffle(pair)
order = ["chosen" if x == chosen else "rejected" for x in pair]
outputs.append({"instruction": input["instruction"], "generations": pair, "order": order})
yield outputs
from shuffle_step import ShuffleStep
定义管道
为了清理现有的偏好数据集,我们需要定义一个包含所有必要步骤的Pipeline
。但是,类似的工作流程可用于清理 SFT 数据集。下面,我们将详细介绍每个步骤。
加载数据集
我们将使用刚刚洗牌的数据集作为源数据。
- 组件:
LoadDataFromDicts
- 输入列:
system
、question
、chosen
、rejected
、generations
和order
,与加载的字典列表中的键相同。 - 输出列:
system
、instruction
、chosen
、rejected
、generations
和order
。我们将使用output_mappings
来重命名列。
load_dataset = LoadDataFromDicts(
data=dataset[:1],
output_mappings={"question": "instruction"},
pipeline=Pipeline(name="showcase-pipeline"),
)
load_dataset.load()
next(load_dataset.process())
评估响应
为了评估响应的质量,我们将使用meta-llama/Meta-Llama-3.1-70B-Instruct
,应用根据不同维度(有用性、诚实性、指令遵循、真实性)判断响应的UltraFeedback
任务。对于 SFT 数据集,您可以使用PrometheusEval
来代替。
- 组件:使用
InferenceEndpointsLLM
的带有 LLM 的UltraFeedback
任务 - 输入列:
instruction
、generations
- 输出列:
ratings
、rationales
、distilabel_metadata
、model_name
为了您的用例并改进结果,您可以使用任何您选择的其他 LLM。
evaluate_responses = UltraFeedback(
aspect="overall-rating",
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
generation_kwargs={"max_new_tokens": 512, "temperature": 0.7},
),
pipeline=Pipeline(name="showcase-pipeline"),
)
evaluate_responses.load()
next(
evaluate_responses.process(
[
{
"instruction": "What's the capital of Spain?",
"generations": ["Madrid", "Barcelona"],
}
]
)
)
仅保留所需的列
我们将删除不需要的列。
- 组件:
KeepColumns
- 输入列:
system
、instruction
、chosen
、rejected
、generations
、ratings
、rationales
、distilabel_metadata
和model_name
- 输出列:
instruction
、chosen
、rejected
、generations
和order
keep_columns = KeepColumns(
columns=[
"instruction",
"generations",
"order",
"ratings",
"rationales",
"model_name",
],
pipeline=Pipeline(name="showcase-pipeline"),
)
keep_columns.load()
next(
keep_columns.process(
[
{
"system": "",
"instruction": "What's the capital of Spain?",
"chosen": "Madrid",
"rejected": "Barcelona",
"generations": ["Madrid", "Barcelona"],
"order": ["chosen", "rejected"],
"ratings": [5, 1],
"rationales": ["", ""],
"model_name": "meta-llama/Meta-Llama-3.1-70B-Instruct",
}
]
)
)
(可选) 进一步的数据整理
您可以使用 Argilla 进一步整理您的数据。
- 组件:
PreferenceToArgilla
步骤 - 输入列:
instruction
、generations
、generation_models
、ratings
- 输出列:
instruction
、generations
、generation_models
、ratings
to_argilla = PreferenceToArgilla(
dataset_name="cleaned-dataset",
dataset_workspace="argilla",
api_url="https://[your-owner-name]-[your-space-name].hf.space",
api_key="[your-api-key]",
num_generations=2,
)
运行管道
下面,您可以看到完整的管道定义
with Pipeline(name="clean-dataset") as pipeline:
load_dataset = LoadDataFromDicts(data=dataset, output_mappings={"question": "instruction"})
evaluate_responses = UltraFeedback(
aspect="overall-rating",
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
generation_kwargs={"max_new_tokens": 512, "temperature": 0.7},
),
)
keep_columns = KeepColumns(
columns=[
"instruction",
"generations",
"order",
"ratings",
"rationales",
"model_name",
]
)
to_argilla = PreferenceToArgilla(
dataset_name="cleaned-dataset",
dataset_workspace="argilla",
api_url="https://[your-owner-name]-[your-space-name].hf.space",
api_key="[your-api-key]",
num_generations=2,
)
load_dataset.connect(evaluate_responses)
evaluate_responses.connect(keep_columns)
keep_columns.connect(to_argilla)
现在让我们运行管道并清理我们的偏好数据集。
distiset = pipeline.run()
让我们检查一下!如果您已将数据加载到 Argilla,则可以在 Argilla UI 中开始标注。
您可以将数据集推送到 Hub 以与社区共享,并嵌入它来探索数据。
distiset.push_to_hub("[your-owner-name]/example-cleaned-preference-dataset")