TRL 文档
Unsloth 集成
加入 Hugging Face 社区
并获得增强的文档体验
开始使用
Unsloth 集成
此部分正在建设中。欢迎贡献!
Unsloth 是一个用于微调和强化学习的开源框架,它能将 LLM(如 Llama、Mistral、Gemma、DeepSeek 等)的训练速度提升高达 2 倍,同时减少高达 70% 的 VRAM 占用,并为训练、评估和部署提供了一个精简的、与 Hugging Face 兼容的工作流程。Unsloth 库与 SFTTrainer 完全兼容。下面列出了在 1 x A100 上的部分基准测试结果
1 A100 40GB | 数据集 | 🤗 | 🤗 + FlashAttention 2 | 🦥 Unsloth | 🦥 节省的 VRAM |
---|---|---|---|---|---|
Code Llama 34b | Slim Orca | 1 倍 | 1.01x | 1.94x | -22.7% |
Llama-2 7b | Slim Orca | 1 倍 | 0.96x | 1.87x | -39.3% |
Mistral 7b | Slim Orca | 1 倍 | 1.17x | 1.88x | -65.9% |
Tiny Llama 1.1b | Alpaca | 1 倍 | 1.55x | 2.74x | -57.8% |
首先,根据官方文档安装 unsloth
。安装后,你可以非常简单地将 unsloth 集成到你的工作流程中;你只需加载一个 FastLanguageModel
,而不是加载 AutoModelForCausalLM
,如下所示
import torch
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel
max_length = 2048 # Supports automatic RoPE Scaling, so choose any number
# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/mistral-7b",
max_seq_length=max_length,
dtype=None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit=True, # Use 4bit quantization to reduce memory usage. Can be False
)
# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
lora_alpha=16,
lora_dropout=0, # Dropout = 0 is currently optimized
bias="none", # Bias = "none" is currently optimized
use_gradient_checkpointing=True,
random_state=3407,
)
training_args = SFTConfig(output_dir="./output", max_length=max_length)
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset,
)
trainer.train()
保存的模型与 Hugging Face 的 transformers 库完全兼容。在他们的官方仓库中了解更多关于 unsloth 的信息。
< > 在 GitHub 上更新