TRL 文档

命令行界面 (CLIs)

Hugging Face's logo
加入 Hugging Face 社区

并获取增强的文档体验

开始使用

命令行界面 (CLIs)

您可以使用 TRL 通过命令行界面 (CLI) 使用诸如监督式微调 (SFT) 或直接策略优化 (DPO) 等方法来微调您的语言模型。

目前支持的 CLI 有

训练命令

  • trl dpo:使用 DPO 微调 LLM
  • trl grpo:使用 GRPO 微调 LLM
  • trl kto:使用 KTO 微调 LLM
  • trl sft:使用 SFT 微调 LLM

其他命令

  • trl env:获取系统信息

使用 CLI 进行微调

开始之前,请从 Hugging Face Hub 中选择一个语言模型。可以在模型中使用过滤器 “text-generation” 找到支持的模型。此外,请确保为您的任务选择相关的数据集。

在使用 sftdpo 命令之前,请确保运行

accelerate config

并为您的训练设置(单/多 GPU、DeepSpeed 等)选择正确的配置。在运行任何 CLI 命令之前,请确保完成 accelerate config 的所有步骤。

我们还建议您传递 YAML 配置文件来配置您的训练协议。以下是 YAML 文件的简单示例,您可以将其用于使用 trl sft 命令训练您的模型。

model_name_or_path:
  Qwen/Qwen2.5-0.5B
dataset_name:
  stanfordnlp/imdb
report_to:
  none
learning_rate:
  0.0001
lr_scheduler_type:
  cosine

将该配置保存在 .yaml 文件中,并立即开始!示例 CLI 配置可在 examples/cli_configs/example_config.yaml 中找到。请注意,您可以通过显式地将参数传递给 CLI 来覆盖配置文件中的参数,例如从根文件夹

trl sft --config examples/cli_configs/example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts

将强制使用 cosine_with_restarts 作为 lr_scheduler_type

支持的参数

我们支持来自 transformers.TrainingArguments 的所有参数,对于加载您的模型,我们支持来自 ~trl.ModelConfig 的所有参数

class trl.ModelConfig

< >

( model_name_or_path: typing.Optional[str] = None model_revision: str = 'main' torch_dtype: typing.Optional[str] = None trust_remote_code: bool = False attn_implementation: typing.Optional[str] = None use_peft: bool = False lora_r: int = 16 lora_alpha: int = 32 lora_dropout: float = 0.05 lora_target_modules: typing.Optional[list[str]] = None lora_modules_to_save: typing.Optional[list[str]] = None lora_task_type: str = 'CAUSAL_LM' use_rslora: bool = False use_dora: bool = False load_in_8bit: bool = False load_in_4bit: bool = False bnb_4bit_quant_type: str = 'nf4' use_bnb_nested_quant: bool = False )

参数

  • model_name_or_path (strNone, 可选, 默认为 None) — 用于权重初始化的模型检查点。
  • model_revision (str, 可选, 默认为 "main") — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID。
  • torch_dtype (Literal["auto", "bfloat16", "float16", "float32"]None, 可选, 默认为 None) — 覆盖默认的 torch.dtype 并在此 dtype 下加载模型。可能的值有:

    • "bfloat16": torch.bfloat16
    • "float16": torch.float16
    • "float32": torch.float32
    • "auto": 从模型的权重自动派生 dtype。
  • trust_remote_code (bool, 可选, 默认为 False) — 是否允许使用 Hub 上定义的自定义模型及其自己的建模文件。此选项仅应为信任的仓库设置为 True,并在您已阅读代码的情况下设置,因为它将在您的本地计算机上执行 Hub 上存在的代码。
  • attn_implementation (strNone, 可选, 默认为 None) — 要使用的注意力实现。您可以运行 --attn_implementation=flash_attention_2,在这种情况下,您必须通过运行 pip install flash-attn --no-build-isolation 手动安装它。
  • use_peft (bool, 可选, 默认为 False) — 是否使用 PEFT 进行训练。
  • lora_r (int, 可选, 默认为 16) — LoRA R 值。
  • lora_alpha (int, 可选, 默认为 32) — LoRA alpha 值。
  • lora_dropout (float, 可选, 默认为 0.05) — LoRA dropout。
  • lora_target_modules (Union[str, list[str]]None, 可选, 默认为 None) — LoRA 目标模块。
  • lora_modules_to_save (list[str]None, 可选, 默认为 None) — 要解冻和训练的模型层。
  • lora_task_type (str, 可选, 默认为 "CAUSAL_LM") — 要为 LoRA 传递的任务类型(奖励建模请使用 "SEQ_CLS")。
  • use_rslora (bool, 可选, 默认为 False) — 是否使用 Rank-Stabilized LoRA,它将适配器缩放因子设置为 lora_alpha/√r,而不是原始默认值 lora_alpha/r
  • use_dora (bool可选,默认为 False) — 启用 Weight-Decomposed Low-Rank Adaptation (DoRA)。这项技术将权重更新分解为两个部分:幅度和方向。方向由普通的 LoRA 处理,而幅度则由一个单独的可学习参数处理。这可以提高 LoRA 的性能,尤其是在低秩的情况下。目前,DoRA 仅支持线性层和 Conv2D 层。DoRA 引入的开销比纯 LoRA 更大,因此建议合并权重以进行推理。
  • load_in_8bit (bool可选,默认为 False) — 是否对基础模型使用 8 位精度。仅适用于 LoRA。
  • load_in_4bit (bool可选,默认为 False) — 是否对基础模型使用 4 位精度。仅适用于 LoRA。
  • bnb_4bit_quant_type (str可选,默认为 "nf4") — 量化类型 ("fp4""nf4")。
  • use_bnb_nested_quant (bool可选,默认为 False) — 是否使用嵌套量化。

模型的配置类。

使用 HfArgumentParser,我们可以将此类转换为 argparse 参数,这些参数可以在命令行中指定。

您可以将这些参数中的任何一个传递给 CLI 或 YAML 文件。

监督式微调 (SFT)

按照上面的基本说明进行操作,并运行 trl sft --output_dir <output_dir> <*args>

trl sft --model_name_or_path facebook/opt-125m --dataset_name stanfordnlp/imdb --output_dir opt-sft-imdb

SFT CLI 基于 trl/scripts/sft.py 脚本。

直接策略优化 (DPO)

要使用 DPO CLI,您需要拥有 TRL 格式的数据集,例如

这些数据集始终至少有三列:promptchosenrejected

  • prompt 是字符串列表。
  • chosen聊天格式 中选择的响应
  • rejected 是拒绝的响应,采用 聊天格式

要快速开始,您可以运行以下命令

trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style

DPO CLI 基于 trl/scripts/dpo.py 脚本。

自定义偏好数据集

将数据集格式化为 TRL 格式(您可以调整 examples/datasets/anthropic_hh.py

python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org

聊天界面

聊天界面已弃用,将在 TRL 0.19 中移除。请改用 transformers-cli chat。有关更多信息,请参阅 Transformers 文档,与文本生成模型聊天

聊天 CLI 可让您快速加载模型并与之对话。只需运行以下命令

$ trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat 
<quentin_gallouedec>:
What is the best programming language?

<Qwen/Qwen1.5-0.5B-Chat>:
There isn't a "best" programming language, as everyone has different style preferences, needs, and preferences. However, some people commonly use   
languages like Python, Java, C++, and JavaScript, which are popular among developers for a variety of reasons, including readability, flexibility,  
and scalability. Ultimately, it depends on personal preference, needs, and goals.

请注意,聊天界面依赖于 tokenizer 的 聊天模板 来格式化模型的输入。请确保您的 tokenizer 已定义聊天模板。

除了与模型对话之外,您还可以使用一些命令

  • clear:清除当前对话并开始新的对话
  • example {NAME}:从配置中加载名为 {NAME} 的示例,并将其用作用户输入
  • set {SETTING_NAME}={SETTING_VALUE};:更改系统提示或生成设置(多个设置用 ; 分隔)。
  • reset:与 clear 相同,但如果生成配置已被 set 更改,则还会将生成配置重置为默认值
  • savesave {SAVE_NAME}:将当前聊天和设置保存到文件,默认保存到 ./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml{SAVE_NAME}(如果提供)
  • exit:关闭界面

获取系统信息

您可以通过运行以下命令来获取系统信息

trl env

这将打印出系统信息,包括 GPU 信息、CUDA 版本、PyTorch 版本、transformers 版本和 TRL 版本,以及任何已安装的可选依赖项。

Copy-paste the following information when reporting an issue:

- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Python version: 3.11.9
- PyTorch version: 2.4.1
- CUDA device: NVIDIA H100 80GB HBM3
- Transformers version: 4.45.0.dev0
- Accelerate version: 0.34.2
- Accelerate config: 
  - compute_environment: LOCAL_MACHINE
  - distributed_type: DEEPSPEED
  - mixed_precision: no
  - use_cpu: False
  - debug: False
  - num_processes: 4
  - machine_rank: 0
  - num_machines: 1
  - rdzv_backend: static
  - same_network: True
  - main_training_function: main
  - enable_cpu_affinity: False
  - deepspeed_config: {'gradient_accumulation_steps': 4, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero_stage': 2}
  - downcast_bf16: no
  - tpu_use_cluster: False
  - tpu_use_sudo: False
  - tpu_env: []
- Datasets version: 3.0.0
- HF Hub version: 0.24.7
- TRL version: 0.12.0.dev0+acb4d70
- bitsandbytes version: 0.41.1
- DeepSpeed version: 0.15.1
- Diffusers version: 0.30.3
- Liger-Kernel version: 0.3.0
- LLM-Blender version: 0.0.2
- OpenAI version: 1.46.0
- PEFT version: 0.12.0

报告问题时需要此信息。

< > 在 GitHub 上更新