TRL 文档

Unsloth 集成

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

Unsloth 集成

此部分正在建设中。欢迎贡献!

Unsloth 是一个用于微调和强化学习的开源框架,它能将 LLM(如 Llama、Mistral、Gemma、DeepSeek 等)的训练速度提升高达 2 倍,同时减少高达 70% 的 VRAM 占用,并为训练、评估和部署提供了一个精简的、与 Hugging Face 兼容的工作流程。Unsloth 库与 SFTTrainer 完全兼容。下面列出了在 1 x A100 上的部分基准测试结果

1 A100 40GB 数据集 🤗 🤗 + FlashAttention 2 🦥 Unsloth 🦥 节省的 VRAM
Code Llama 34b Slim Orca 1 倍 1.01x 1.94x -22.7%
Llama-2 7b Slim Orca 1 倍 0.96x 1.87x -39.3%
Mistral 7b Slim Orca 1 倍 1.17x 1.88x -65.9%
Tiny Llama 1.1b Alpaca 1 倍 1.55x 2.74x -57.8%

首先,根据官方文档安装 unsloth。安装后,你可以非常简单地将 unsloth 集成到你的工作流程中;你只需加载一个 FastLanguageModel,而不是加载 AutoModelForCausalLM,如下所示

import torch
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel

max_length = 2048 # Supports automatic RoPE Scaling, so choose any number

# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/mistral-7b",
    max_seq_length=max_length,
    dtype=None,  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
    load_in_4bit=True,  # Use 4bit quantization to reduce memory usage. Can be False
)

# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,  # Dropout = 0 is currently optimized
    bias="none",  # Bias = "none" is currently optimized
    use_gradient_checkpointing=True,
    random_state=3407,
)

training_args = SFTConfig(output_dir="./output", max_length=max_length)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

保存的模型与 Hugging Face 的 transformers 库完全兼容。在他们的官方仓库中了解更多关于 unsloth 的信息。

< > 在 GitHub 上更新