使用 ORPO 微调 Llama 3

社区文章 发布于 2024 年 4 月 22 日

ORPO 是一种**激动人心的新微调技术**,它将传统的监督微调和偏好对齐阶段结合到一个单一过程中。这减少了训练所需的计算资源和时间。此外,实证结果表明,ORPO 在各种模型尺寸和基准上均优于其他对齐方法。

在本文中,我们将使用 TRL 库通过 ORPO 微调新的 Llama 3 8B 模型。代码可在 Google Colab 和 GitHub 上的 LLM 课程 中获取。

⚖️ ORPO

指令调优和偏好对齐是使大型语言模型 (LLM) 适应特定任务的基本技术。传统上,这涉及一个多阶段过程:1/ 在指令上进行**监督微调** (SFT),以使模型适应目标领域;然后 2/ 使用**偏好对齐方法**,如带有人类反馈的强化学习 (RLHF) 或直接偏好优化 (DPO),以增加生成首选响应而非拒绝响应的可能性。

然而,研究人员发现了这种方法的局限性。虽然 SFT 能有效地使模型适应所需的领域,但它无意中**增加了生成不良答案的概率**,同时生成了首选答案。这就是为什么偏好对齐阶段对于扩大首选输出和被拒绝输出的可能性之间的差距是必要的。

请注意,在监督微调期间,拒绝响应的概率是如何增加的(来自 ORPO 论文)。

Hong 和 Lee (2024) 引入的 ORPO 通过将指令调优和偏好对齐结合到一个单一的、整体的训练过程中,为这个问题提供了一个优雅的解决方案。ORPO 修改了标准语言建模目标,将负对数似然损失与赔率比 (OR) 项结合起来。这种 OR 损失对被拒绝的响应进行弱惩罚,同时强烈奖励首选的响应,从而使模型能够同时学习目标任务并与人类偏好保持一致。

LORPO=E(x,yw,yl)[LSFT+λLOR] ORPO 已在主要的微调库中实现,例如 TRLAxolotlLLaMA-Factory。在下一节中,我们将看到如何在 TRL 中使用它。

💻 使用 ORPO 微调 Llama 3

Llama 3 是 Meta 开发的最新 LLM 家族。这些模型在一个包含 **15 万亿个 token** 的庞大数据集上进行训练(相比之下,Llama 2 为 2 万亿个 token)。已发布了两种模型大小:一个 700 亿参数模型和一个较小的 80 亿参数模型。700 亿参数模型已经表现出令人印象深刻的性能,在 MMLU 基准测试中得分 82,在 HumanEval 基准测试中得分 81.7。

Llama 3 模型还将上下文长度增加到 8,192 个 token(Llama 2 为 4,096 个 token),并可能通过 RoPE 扩展到 32k。此外,这些模型使用了一个拥有 128K token 词汇量的新 tokenizer,将编码文本所需的 token 数量减少了 15%。这种词汇量也解释了从 7B 到 8B 参数的提升。

来自 ORPO-DPO-mix-40k 的样本。

ORPO 需要一个偏好数据集,包括一个提示、一个选择的答案和一个被拒绝的答案。在此示例中,我们将使用 mlabonne/orpo-dpo-mix-40k,它是以下高质量 DPO 数据集的组合

感谢 argillaunalignmentM4-aijondurbin 提供了源数据集。

像往常一样,我们首先安装所需的库

pip install -U transformers datasets accelerate peft trl bitsandbytes wandb

安装完成后,我们可以导入必要的库并登录 W&B(可选)

import gc
import os

import torch
import wandb
from datasets import load_dataset
from google.colab import userdata
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
)
from trl import ORPOConfig, ORPOTrainer, setup_chat_format

wb_token = userdata.get('wandb')
wandb.login(key=wb_token)

如果您拥有最新的 GPU,您还应该能够使用 Flash Attention 库 来替换默认的急切注意力实现,以获得更高效的实现。

if torch.cuda.get_device_capability()[0] >= 8:
    !pip install -qqq flash-attn
    attn_implementation = "flash_attention_2"
    torch_dtype = torch.bfloat16
else:
    attn_implementation = "eager"
    torch_dtype = torch.float16

接下来,我们将使用 bitsandbytes 以 4 位精度加载 Llama 3 8B 模型。然后,我们使用 PEFT 为 QLoRA 设置 LoRA 配置。我还使用了便捷的 setup_chat_format() 函数来修改模型和 tokenizer,以支持 ChatML。它会自动应用此聊天模板,添加特殊 token,并调整模型的嵌入层大小以匹配新的词汇大小。

请注意,您需要提交请求才能访问 meta-llama/Meta-Llama-3-8B 并登录您的 Hugging Face 帐户。或者,您可以加载未受限的模型副本,例如 NousResearch/Meta--Llama-3-8B

# Model
base_model = "meta-llama/Meta-Llama-3-8B"
new_model = "OrpoLlama-3-8B"

# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)
model, tokenizer = setup_chat_format(model, tokenizer)
model = prepare_model_for_kbit_training(model)

模型准备好训练后,我们就可以处理数据集了。我们加载 mlabonne/orpo-dpo-mix-40k 并使用 apply_chat_template() 函数将“chosen”和“rejected”列转换为 ChatML 格式。请注意,我只使用了 1,000 个样本,而不是整个数据集,因为运行时间太长。

dataset_name = "mlabonne/orpo-dpo-mix-40k"
dataset = load_dataset(dataset_name, split="all")
dataset = dataset.shuffle(seed=42).select(range(1000))

def format_chat_template(row):
    row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
    row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
    return row

dataset = dataset.map(
    format_chat_template,
    num_proc= os.cpu_count(),
)
dataset = dataset.train_test_split(test_size=0.01)

首先,我们需要设置一些超参数

  • learning_rate: ORPO 使用的学习率比传统 SFT 甚至 DPO 都低得多。这个 8e-6 的值来自原始论文,大致对应于 SFT 学习率 1e-5 和 DPO 学习率 5e-6。对于实际的微调,我建议将其提高到 1e-6 左右。
  • beta:它是论文中的 $\lambda$ 参数,默认值为 0.1。原始论文的附录通过消融研究展示了它的选择方式。
  • 其他参数,如 max_length 和批处理大小,被设置为尽可能多地使用 VRAM(在此配置中约为 20 GB)。理想情况下,我们将训练模型 3-5 个 epoch,但这里我们只训练 1 个 epoch。

最后,我们可以使用 ORPOTrainer 训练模型,它充当一个包装器。

orpo_args = ORPOConfig(
    learning_rate=8e-6,
    beta=0.1,
    lr_scheduler_type="linear",
    max_length=1024,
    max_prompt_length=512,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    optim="paged_adamw_8bit",
    num_train_epochs=1,
    evaluation_strategy="steps",
    eval_steps=0.2,
    logging_steps=1,
    warmup_steps=10,
    report_to="wandb",
    output_dir="./results/",
)

trainer = ORPOTrainer(
    model=model,
    args=orpo_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,
    tokenizer=tokenizer,
)
trainer.train()
trainer.save_model(new_model)

在 L4 GPU 上训练这 1,000 个样本大约需要 2 小时。让我们检查一下 W&B 图表

虽然损失下降了,但选中答案和拒绝答案之间的差异并不明显:平均裕度和准确度分别略高于零和 0.5。

在原始论文中,作者在 Anthropic/hh-rlhf 数据集(161k 样本)上训练了模型 10 个 epoch,这比我们快速运行的时间长得多。他们还对 Llama 3 进行了实验,并好心地 与我分享了他们的日志(感谢 Jiwoo Hong)。

为了结束本教程,让我们将 QLoRA 适配器与基本模型合并并将其推送到 Hugging Face Hub。

# Flush memory
del trainer, model
gc.collect()
torch.cuda.empty_cache()

# Reload tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
model, tokenizer = setup_chat_format(model, tokenizer)

# Merge adapter with base model
model = PeftModel.from_pretrained(model, new_model)
model = model.merge_and_unload()

model.push_to_hub(new_model, use_temp_dir=False)
tokenizer.push_to_hub(new_model, use_temp_dir=False)

恭喜,我们完成了 Llama 3 的快速微调:mlabonne/OrpoLlama-3-8B。您可以使用此 Hugging Face Space 进行试玩(这是一个 notebook,可以制作自己的)。尽管模型训练不足,正如 W&B 曲线所强调的那样,我使用 LLM AutoEval 对 Nous 的基准测试套件进行了一些评估。

模型 平均分 AGIEval GPT4All TruthfulQA Bigbench
teknium/OpenHermes-2.5-Mistral-7B 📄 52.42 42.75 72.99 52.99 40.94
meta-llama/Meta-Llama-3-8B-Instruct 📄 51.34 41.22 69.86 51.65 42.64
mistralai/Mistral-7B-Instruct-v0.1 📄 49.15 33.36 67.87 55.89 39.48
mlabonne/OrpoLlama-3-8B 📄 46.76 31.56 70.19 48.11 37.17
meta-llama/Meta-Llama-3-8B 📄 45.42 31.1 69.95 43.91 36.7

我们的 ORPO 微调实际上相当不错,并改进了基本模型在所有基准上的性能。这令人鼓舞,可能意味着对全部 4 万个样本进行微调将产生出色的结果。

对于开源社区来说,这是一个激动人心的时刻,越来越多的高质量开放权重模型正在发布。闭源模型和开放权重模型之间的差距正在慢慢缩小,微调是为您的用例获得最佳性能的重要工具。

image/png

结论

在本文中,我们介绍了 ORPO 算法,并解释了它如何将 SFT 和偏好对齐阶段统一为一个单一过程。然后,我们使用 TRL 在自定义偏好数据集上微调了 Llama 3 8B 模型。最终模型显示出令人鼓舞的结果,并突出了 ORPO 作为一种新微调范式的潜力。

希望这能有所帮助,我建议运行 Colab 笔记本 来微调您自己的 Llama 3 模型。在未来的文章中,我们将探讨如何创建高质量数据集——这是一个经常被忽视的问题。如果您喜欢这篇文章,请在 Hugging Face 和 Twitter @maximelabonne 上关注我。

参考文献

社区

注册登录 发表评论