PEFT 文档

快速教程

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

快速入门

PEFT 提供了参数高效的方法来微调大型预训练模型。传统的范式是为每个下游任务微调模型的所有参数,但由于当今模型中参数数量巨大,这种方法变得极其昂贵且不切实际。相反,训练更少量的提示参数,或者使用像低秩适应(LoRA)这样的重参数化方法来减少可训练参数的数量会更高效。

本快速入门将向您展示 PEFT 的主要功能,以及如何训练或在通常无法在消费级设备上运行的大型模型上进行推理。

训练

每种 PEFT 方法都由一个 PeftConfig 类定义,该类存储了构建 PeftModel 所需的所有重要参数。例如,要使用 LoRA 进行训练,加载并创建一个 LoraConfig 类,并指定以下参数:

  • task_type:训练任务的类型(本例中为序列到序列语言建模)
  • inference_mode:是否将模型用于推理
  • r:低秩矩阵的维度
  • lora_alpha:低秩矩阵的缩放因子
  • lora_dropout:LoRA 层的 dropout 概率
from peft import LoraConfig, TaskType

peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)

有关您可以调整的其他参数的更多详细信息,例如目标模块或偏置类型,请参阅 LoraConfig 参考。

一旦设置好 LoraConfig,就可以使用 get_peft_model() 函数创建一个 PeftModel。它接受一个基础模型——您可以从 Transformers 库加载——以及一个包含如何配置模型以进行 LoRA 训练的参数的 LoraConfig

加载您想要微调的基础模型。

from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/mt0-large")

使用 get_peft_model() 函数将基础模型和 `peft_config` 包装起来,以创建一个 PeftModel。要了解模型中可训练参数的数量,请使用 `print_trainable_parameters` 方法。

from peft import get_peft_model

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
"output: trainable params: 2359296 || all params: 1231940608 || trainable%: 0.19151053100118282"

bigscience/mt0-large 的 12 亿个参数中,您只训练了其中的 0.19%!

就是这样 🎉!现在您可以使用 Transformers 的 `Trainer`、Accelerate 或任何自定义的 PyTorch 训练循环来训练模型。

例如,要使用 `Trainer` 类进行训练,需要设置一个包含一些训练超参数的 `TrainingArguments` 类。

training_args = TrainingArguments(
    output_dir="your-name/bigscience/mt0-large-lora",
    learning_rate=1e-3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=2,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

将模型、训练参数、数据集、分词器以及任何其他必要组件传递给 `Trainer`,然后调用 `train` 开始训练。

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

保存模型

模型训练完成后,您可以使用 `save_pretrained` 函数将模型保存到目录中。

model.save_pretrained("output_dir")

您也可以使用 `push_to_hub` 函数将模型保存到 Hub(首先请确保您已登录您的 Hugging Face 账户)。

from huggingface_hub import notebook_login

notebook_login()
model.push_to_hub("your-name/bigscience/mt0-large-lora")

这两种方法都只保存经过训练的额外 PEFT 权重,这意味着存储、传输和加载都非常高效。例如,这个用 LoRA 训练的 facebook/opt-350m 模型只包含两个文件:`adapter_config.json` 和 `adapter_model.safetensors`。`adapter_model.safetensors` 文件仅为 6.3MB!

存储在 Hub 上的 opt-350m 模型的适配器权重只有约 6MB,而完整模型权重的大小可能约为 700MB。

推理

有关可用 `AutoPeftModel` 类的完整列表,请参阅 AutoPeftModel API 参考。

使用 AutoPeftModel 类和 `from_pretrained` 方法,可以轻松加载任何经过 PEFT 训练的模型进行推理。

from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer
import torch

model = AutoPeftModelForCausalLM.from_pretrained("ybelkada/opt-350m-lora")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

model = model.to("cuda")
model.eval()
inputs = tokenizer("Preheat the oven to 350 degrees and place the cookie dough", return_tensors="pt")

outputs = model.generate(input_ids=inputs["input_ids"].to("cuda"), max_new_tokens=50)
print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])

"Preheat the oven to 350 degrees and place the cookie dough in the center of the oven. In a large bowl, combine the flour, baking powder, baking soda, salt, and cinnamon. In a separate bowl, combine the egg yolks, sugar, and vanilla."

对于其他未被 `AutoPeftModelFor` 类明确支持的任务——例如自动语音识别——您仍然可以使用基础的 AutoPeftModel 类来为该任务加载模型。

from peft import AutoPeftModel

model = AutoPeftModel.from_pretrained("smangrul/openai-whisper-large-v2-LORA-colab")

后续步骤

现在您已经了解了如何使用其中一种 PEFT 方法训练模型,我们鼓励您尝试其他一些方法,比如提示微调。步骤与本快速入门中展示的非常相似:

  1. 为 PEFT 方法准备一个 PeftConfig
  2. 使用 get_peft_model() 方法从配置和基础模型中创建一个 PeftModel

然后您可以随心所欲地训练它!要加载 PEFT 模型进行推理,您可以使用 AutoPeftModel 类。

如果您有兴趣为特定任务(如语义分割、多语言自动语音识别、DreamBooth、词元分类等)使用另一种 PEFT 方法训练模型,也欢迎查阅任务指南。

< > 在 GitHub 上更新