TRL 文档
命令行界面 (CLIs)
并获取增强的文档体验
开始使用
命令行界面 (CLIs)
您可以使用 TRL 通过命令行界面 (CLI) 使用诸如监督式微调 (SFT) 或直接策略优化 (DPO) 等方法来微调您的语言模型。
目前支持的 CLI 有
训练命令
trl dpo
:使用 DPO 微调 LLMtrl grpo
:使用 GRPO 微调 LLMtrl kto
:使用 KTO 微调 LLMtrl sft
:使用 SFT 微调 LLM
其他命令
trl env
:获取系统信息
使用 CLI 进行微调
开始之前,请从 Hugging Face Hub 中选择一个语言模型。可以在模型中使用过滤器 “text-generation” 找到支持的模型。此外,请确保为您的任务选择相关的数据集。
在使用 sft
或 dpo
命令之前,请确保运行
accelerate config
并为您的训练设置(单/多 GPU、DeepSpeed 等)选择正确的配置。在运行任何 CLI 命令之前,请确保完成 accelerate config
的所有步骤。
我们还建议您传递 YAML 配置文件来配置您的训练协议。以下是 YAML 文件的简单示例,您可以将其用于使用 trl sft
命令训练您的模型。
model_name_or_path:
Qwen/Qwen2.5-0.5B
dataset_name:
stanfordnlp/imdb
report_to:
none
learning_rate:
0.0001
lr_scheduler_type:
cosine
将该配置保存在 .yaml
文件中,并立即开始!示例 CLI 配置可在 examples/cli_configs/example_config.yaml
中找到。请注意,您可以通过显式地将参数传递给 CLI 来覆盖配置文件中的参数,例如从根文件夹
trl sft --config examples/cli_configs/example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
将强制使用 cosine_with_restarts
作为 lr_scheduler_type
。
支持的参数
我们支持来自 transformers.TrainingArguments
的所有参数,对于加载您的模型,我们支持来自 ~trl.ModelConfig
的所有参数
class trl.ModelConfig
< source >( model_name_or_path: typing.Optional[str] = None model_revision: str = 'main' torch_dtype: typing.Optional[str] = None trust_remote_code: bool = False attn_implementation: typing.Optional[str] = None use_peft: bool = False lora_r: int = 16 lora_alpha: int = 32 lora_dropout: float = 0.05 lora_target_modules: typing.Optional[list[str]] = None lora_modules_to_save: typing.Optional[list[str]] = None lora_task_type: str = 'CAUSAL_LM' use_rslora: bool = False use_dora: bool = False load_in_8bit: bool = False load_in_4bit: bool = False bnb_4bit_quant_type: str = 'nf4' use_bnb_nested_quant: bool = False )
参数
- model_name_or_path (
str
或None
, 可选, 默认为None
) — 用于权重初始化的模型检查点。 - model_revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID。 - torch_dtype (
Literal["auto", "bfloat16", "float16", "float32"]
或None
, 可选, 默认为None
) — 覆盖默认的torch.dtype
并在此 dtype 下加载模型。可能的值有:"bfloat16"
:torch.bfloat16
"float16"
:torch.float16
"float32"
:torch.float32
"auto"
: 从模型的权重自动派生 dtype。
- trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许使用 Hub 上定义的自定义模型及其自己的建模文件。此选项仅应为信任的仓库设置为True
,并在您已阅读代码的情况下设置,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - attn_implementation (
str
或None
, 可选, 默认为None
) — 要使用的注意力实现。您可以运行--attn_implementation=flash_attention_2
,在这种情况下,您必须通过运行pip install flash-attn --no-build-isolation
手动安装它。 - use_peft (
bool
, 可选, 默认为False
) — 是否使用 PEFT 进行训练。 - lora_r (
int
, 可选, 默认为16
) — LoRA R 值。 - lora_alpha (
int
, 可选, 默认为32
) — LoRA alpha 值。 - lora_dropout (
float
, 可选, 默认为0.05
) — LoRA dropout。 - lora_target_modules (
Union[str, list[str]]
或None
, 可选, 默认为None
) — LoRA 目标模块。 - lora_modules_to_save (
list[str]
或None
, 可选, 默认为None
) — 要解冻和训练的模型层。 - lora_task_type (
str
, 可选, 默认为"CAUSAL_LM"
) — 要为 LoRA 传递的任务类型(奖励建模请使用"SEQ_CLS"
)。 - use_rslora (
bool
, 可选, 默认为False
) — 是否使用 Rank-Stabilized LoRA,它将适配器缩放因子设置为lora_alpha/√r
,而不是原始默认值lora_alpha/r
。 - use_dora (
bool
,可选,默认为False
) — 启用 Weight-Decomposed Low-Rank Adaptation (DoRA)。这项技术将权重更新分解为两个部分:幅度和方向。方向由普通的 LoRA 处理,而幅度则由一个单独的可学习参数处理。这可以提高 LoRA 的性能,尤其是在低秩的情况下。目前,DoRA 仅支持线性层和 Conv2D 层。DoRA 引入的开销比纯 LoRA 更大,因此建议合并权重以进行推理。 - load_in_8bit (
bool
,可选,默认为False
) — 是否对基础模型使用 8 位精度。仅适用于 LoRA。 - load_in_4bit (
bool
,可选,默认为False
) — 是否对基础模型使用 4 位精度。仅适用于 LoRA。 - bnb_4bit_quant_type (
str
,可选,默认为"nf4"
) — 量化类型 ("fp4"
或"nf4"
)。 - use_bnb_nested_quant (
bool
,可选,默认为False
) — 是否使用嵌套量化。
模型的配置类。
使用 HfArgumentParser
,我们可以将此类转换为 argparse 参数,这些参数可以在命令行中指定。
您可以将这些参数中的任何一个传递给 CLI 或 YAML 文件。
监督式微调 (SFT)
按照上面的基本说明进行操作,并运行 trl sft --output_dir <output_dir> <*args>
trl sft --model_name_or_path facebook/opt-125m --dataset_name stanfordnlp/imdb --output_dir opt-sft-imdb
SFT CLI 基于 trl/scripts/sft.py
脚本。
直接策略优化 (DPO)
要使用 DPO CLI,您需要拥有 TRL 格式的数据集,例如
- TRL 的 Anthropic HH 数据集:https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-helpful-base-trl-style
- TRL 的 OpenAI TL;DR 摘要数据集:https://huggingface.co/datasets/trl-internal-testing/tldr-preference-trl-style
这些数据集始终至少有三列:prompt
、chosen
和 rejected
要快速开始,您可以运行以下命令
trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style
DPO CLI 基于 trl/scripts/dpo.py
脚本。
自定义偏好数据集
将数据集格式化为 TRL 格式(您可以调整 examples/datasets/anthropic_hh.py
)
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
聊天界面
聊天界面已弃用,将在 TRL 0.19 中移除。请改用 transformers-cli chat
。有关更多信息,请参阅 Transformers 文档,与文本生成模型聊天。
聊天 CLI 可让您快速加载模型并与之对话。只需运行以下命令
$ trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
<quentin_gallouedec>:
What is the best programming language?
<Qwen/Qwen1.5-0.5B-Chat>:
There isn't a "best" programming language, as everyone has different style preferences, needs, and preferences. However, some people commonly use
languages like Python, Java, C++, and JavaScript, which are popular among developers for a variety of reasons, including readability, flexibility,
and scalability. Ultimately, it depends on personal preference, needs, and goals.
请注意,聊天界面依赖于 tokenizer 的 聊天模板 来格式化模型的输入。请确保您的 tokenizer 已定义聊天模板。
除了与模型对话之外,您还可以使用一些命令
clear
:清除当前对话并开始新的对话example {NAME}
:从配置中加载名为{NAME}
的示例,并将其用作用户输入set {SETTING_NAME}={SETTING_VALUE};
:更改系统提示或生成设置(多个设置用;
分隔)。reset
:与 clear 相同,但如果生成配置已被set
更改,则还会将生成配置重置为默认值save
或save {SAVE_NAME}
:将当前聊天和设置保存到文件,默认保存到./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml
或{SAVE_NAME}
(如果提供)exit
:关闭界面
获取系统信息
您可以通过运行以下命令来获取系统信息
trl env
这将打印出系统信息,包括 GPU 信息、CUDA 版本、PyTorch 版本、transformers 版本和 TRL 版本,以及任何已安装的可选依赖项。
Copy-paste the following information when reporting an issue: - Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31 - Python version: 3.11.9 - PyTorch version: 2.4.1 - CUDA device: NVIDIA H100 80GB HBM3 - Transformers version: 4.45.0.dev0 - Accelerate version: 0.34.2 - Accelerate config: - compute_environment: LOCAL_MACHINE - distributed_type: DEEPSPEED - mixed_precision: no - use_cpu: False - debug: False - num_processes: 4 - machine_rank: 0 - num_machines: 1 - rdzv_backend: static - same_network: True - main_training_function: main - enable_cpu_affinity: False - deepspeed_config: {'gradient_accumulation_steps': 4, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero_stage': 2} - downcast_bf16: no - tpu_use_cluster: False - tpu_use_sudo: False - tpu_env: [] - Datasets version: 3.0.0 - HF Hub version: 0.24.7 - TRL version: 0.12.0.dev0+acb4d70 - bitsandbytes version: 0.41.1 - DeepSpeed version: 0.15.1 - Diffusers version: 0.30.3 - Liger-Kernel version: 0.3.0 - LLM-Blender version: 0.0.2 - OpenAI version: 1.46.0 - PEFT version: 0.12.0
报告问题时需要此信息。
< > 在 GitHub 上更新