TRL 文档

SFT 训练器

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

SFT 训练器

All_models-SFT-blue smol_course-Chapter_1-yellow

概览

TRL 支持用于训练语言模型的监督式微调 (Supervised Fine-Tuning, SFT) 训练器。

这种后训练方法由 Younes Belkada 贡献。

快速入门

本示例演示了如何使用 TRL 中的 SFTTrainer 来训练语言模型。我们将在 Capybara 数据集上训练一个 Qwen 3 0.6B 模型,这是一个紧凑、多样化的多轮对话数据集,用于基准测试推理和泛化能力。

from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B",
    train_dataset=load_dataset("trl-lib/Capybara", split="train"),
)
trainer.train()

预期的数据集类型和格式

SFT 支持语言建模提示-补全两种类型的数据集。SFTTrainer 兼容标准对话式两种数据集格式。当提供对话式数据集时,训练器会自动将聊天模板应用于数据集。

# Standard language modeling
{"text": "The sky is blue."}

# Conversational language modeling
{"messages": [{"role": "user", "content": "What color is the sky?"},
              {"role": "assistant", "content": "It is blue."}]}

# Standard prompt-completion
{"prompt": "The sky is",
 "completion": " blue."}

# Conversational prompt-completion
{"prompt": [{"role": "user", "content": "What color is the sky?"}],
 "completion": [{"role": "assistant", "content": "It is blue."}]}

如果你的数据集不属于这些格式之一,你可以对其进行预处理,将其转换为预期格式。以下是使用 FreedomIntelligence/medical-o1-reasoning-SFT 数据集的一个示例。

from datasets import load_dataset

dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en")

def preprocess_function(example):
    return {
        "prompt": [{"role": "user", "content": example["Question"]}],
        "completion": [
            {"role": "assistant", "content": f"<think>{example['Complex_CoT']}</think>{example['Response']}"}
        ],
    }

dataset = dataset.map(preprocess_function, remove_columns=["Question", "Response", "Complex_CoT"])
print(next(iter(dataset["train"])))
{
    "prompt": [
        {
            "content": "Given the symptoms of sudden weakness in the left arm and leg, recent long-distance travel, and the presence of swollen and tender right lower leg, what specific cardiac abnormality is most likely to be found upon further evaluation that could explain these findings?",
            "role": "user",
        }
    ],
    "completion": [
        {
            "content": "<think>Okay, let's see what's going on here. We've got sudden weakness [...] clicks into place!</think>The specific cardiac abnormality most likely to be found in [...] the presence of a PFO facilitating a paradoxical embolism.",
            "role": "assistant",
        }
    ],
}

深入探究 SFT 方法

监督式微调(SFT)是使语言模型适应目标数据集的最简单、最常用的方法。该模型以完全监督的方式,使用输入和输出序列对进行训练。目标是最小化目标序列的负对数似然(NLL),并以输入为条件。

本节将分解 SFT 在实践中如何工作,涵盖关键步骤:**预处理**、**分词**和**损失计算**。

预处理和分词

在训练期间,根据数据集格式,每个示例预计包含一个**文本字段**或一个**(提示,补全)**对。有关预期格式的更多详细信息,请参阅数据集格式SFTTrainer 使用模型的分词器对每个输入进行分词。如果提示和补全是分开提供的,它们会在分词前被拼接起来。

计算损失

sft_figure

SFT 中使用的损失是词元级交叉熵损失,定义为LSFT(θ)=t=1Tlogpθ(yty<t), \mathcal{L}_{\text{SFT}}(\theta) = - \sum_{t=1}^{T} \log p_\theta(y_t \mid y_{<t}),

其中其中 yt y_t 是时间步t t 的目标词元,模型被训练来预测给定前面所有词元的下一个词元。在实践中,填充词元在损失计算中被掩码掉。

标签移位和掩码

在训练期间,损失是使用**单词元移位**计算的:模型被训练来基于所有先前的词元预测序列中的每个词元。具体来说,输入序列向右移动一个位置以形成目标标签。填充词元(如果存在)通过在相应位置应用忽略索引(默认为 -100)在损失计算中被忽略。这确保了损失只关注有意义的、非填充的词元。

日志指标

  • global_step:到目前为止已执行的优化器步骤总数。
  • epoch:当前的 epoch 数,基于数据集的迭代。
  • num_tokens:到目前为止已处理的词元总数。
  • loss:在当前日志记录间隔内,对非掩码词元计算的平均交叉熵损失。
  • mean_token_accuracy:模型的 top-1 预测与真实词元匹配的非掩码词元的比例。
  • learning_rate:当前学习率,如果使用调度器,可能会动态变化。
  • grad_norm:梯度的 L2 范数,在梯度裁剪之前计算。

自定义

模型初始化

你可以直接将 `from_pretrained()` 方法的关键字参数传递给 SFTConfig。例如,如果你想以不同的精度加载模型,类似于

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype=torch.bfloat16)

你可以通过将 `model_init_kwargs={"torch_dtype": torch.bfloat16}` 参数传递给 SFTConfig 来实现。

from trl import SFTConfig

training_args = SFTConfig(
    model_init_kwargs={"torch_dtype": torch.bfloat16},
)

请注意,`from_pretrained()` 的所有关键字参数都受支持。

打包

SFTTrainer 支持*示例打包*,即在同一个输入序列中打包多个示例以提高训练效率。要启用打包功能,只需在 SFTConfig 构造函数中传递 `packing=True`。

training_args = SFTConfig(packing=True)

有关打包的更多详细信息,请参阅打包

只在助手消息上训练

要只在助手消息上训练,请使用一个对话式数据集,并在 SFTConfig 中设置 `assistant_only_loss=True`。此设置确保损失**只**在助手回复上计算,而忽略用户或系统消息。

training_args = SFTConfig(assistant_only_loss=True)

train_on_assistant

此功能仅适用于支持通过 `{% generation %}` 和 `{% endgeneration %}` 关键字返回助手词元掩码的聊天模板。有关此类模板的示例,请参阅 HugggingFaceTB/SmolLM3-3B

只在补全部分训练

要只在补全部分训练,请使用提示-补全数据集。默认情况下,训练器仅在补全词元上计算损失,忽略提示词元。如果你想在完整序列上训练,请在 SFTConfig 中设置 `completion_only_loss=False`。

train_on_completion

只在补全部分训练与只在助手消息上训练兼容。在这种情况下,请使用[对话式](dataset_formats#conversational)[提示-补全](dataset_formats#prompt-completion)数据集,并在 [SFTConfig](/docs/trl/v0.21.0/en/sft_trainer#trl.SFTConfig) 中设置 `assistant_only_loss=True`。

使用 PEFT 训练适配器

我们支持与 🤗 PEFT 库的紧密集成,允许任何用户方便地训练适配器并在 Hub 上分享,而不是训练整个模型。

from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig

dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    "Qwen/Qwen3-0.6B",
    train_dataset=dataset,
    peft_config=LoraConfig()
)

trainer.train()

你也可以继续训练你的 `peft.PeftModel`。为此,首先在 SFTTrainer 外部加载一个 `PeftModel`,然后将其直接传递给训练器,而不需要传递 `peft_config` 参数。

from datasets import load_dataset
from trl import SFTTrainer
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained("trl-lib/Qwen3-4B-LoRA", is_trainable=True)
dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
)

trainer.train()

训练适配器时,通常使用更高的学习率(≈1e-4),因为只学习新的参数。

SFTConfig(learning_rate=1e-4, ...)

使用 Liger Kernel 进行训练

Liger Kernel 是一系列用于 LLM 训练的 Triton 内核,可将多 GPU 吞吐量提升 20%,内存使用量减少 60%(支持高达 4 倍的上下文长度),并与 FlashAttention、PyTorch FSDP 和 DeepSpeed 等工具无缝协作。更多信息,请参阅Liger Kernel 集成

使用 Unsloth 进行训练

Unsloth 是一个开源的微调和强化学习框架,可以使 LLMs(如 Llama、Mistral、Gemma、DeepSeek 等)的训练速度提高 2 倍,VRAM 使用量减少高达 70%,同时为训练、评估和部署提供了简化的、与 Hugging Face 兼容的工作流程。更多信息,请参阅Unsloth 集成

指令调优示例

指令调优教导基础语言模型遵循用户指令并进行对话。这需要

  1. 聊天模板:定义如何将对话构造成文本序列,包括角色标记(用户/助手)、特殊词元和对话轮次边界。在聊天模板中阅读更多关于聊天模板的信息。
  2. 对话数据集:包含指令-响应对

此示例展示了如何使用 Capybara 数据集和来自 HuggingFaceTB/SmolLM3-3B 的聊天模板,将 Qwen 3 0.6B Base 模型转换为一个指令遵循模型。SFT 训练器会自动处理分词器更新和特殊词元配置。

from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B-Base",
    args=SFTConfig(
        output_dir="Qwen3-0.6B-Instruct",
        chat_template_path="HuggingFaceTB/SmolLM3-3B",
    ),
    train_dataset=load_dataset("trl-lib/Capybara", split="train"),
)
trainer.train()

一些基础模型,如 Qwen 系列模型,在其分词器中预定义了聊天模板。在这些情况下,不必应用 `clone_chat_template()`,因为分词器已经处理了格式化。但是,有必要将 EOS 词元与聊天模板对齐,以确保模型的响应正确终止。在这些情况下,在 SFTConfig 中指定 `eos_token`;例如,对于 `Qwen/Qwen2.5-1.5B`,应设置 `eos_token="<|im_end|>"`。

训练完成后,你的模型现在可以使用其新的聊天模板来遵循指令并进行对话。

>>> from transformers import pipeline
>>> pipe = pipeline("text-generation", model="Qwen3-0.6B-Instruct/checkpoint-5000")
>>> prompt = "<|im_start|>user\nWhat is the capital of France? Answer in one word.<|im_end|>\n<|im_start|>assistant\n"
>>> response = pipe(prompt)
>>> response[0]["generated_text"]
'<|im_start|>user\nWhat is the capital of France? Answer in one word.<|im_end|>\n<|im_start|>assistant\nThe capital of France is Paris.'

或者,使用结构化对话格式(推荐)

>>> prompt = [{"role": "user", "content": "What is the capital of France? Answer in one word."}]
>>> response = pipe(prompt)
>>> response[0]["generated_text"]
[{'role': 'user', 'content': 'What is the capital of France? Answer in one word.'}, {'role': 'assistant', 'content': 'The capital of France is Paris.'}]

使用 SFT 进行工具调用

SFT 训练器完全支持对具有*工具调用*能力的模型进行微调。在这种情况下,每个数据集示例应包含

  • 对话消息,包括任何工具调用(`tool_calls`)和工具响应(`tool` 角色消息)
  • `tools` 列中的可用工具列表,通常以 JSON 模式提供

有关预期数据集结构的详细信息,请参阅数据集格式 — 工具调用部分。

为视觉语言模型扩展 SFTTrainer

SFTTrainer 目前尚未原生支持视觉语言数据。但是,我们提供了一个关于如何调整训练器以支持视觉语言数据的指南。具体来说,您需要使用一个与视觉语言数据兼容的自定义数据整理器。本指南概述了进行这些调整的步骤。有关具体示例,请参阅脚本 examples/scripts/sft_vlm.py,该脚本演示了如何在 HuggingFaceH4/llava-instruct-mix-vsft 数据集上微调 LLaVA 1.5 模型。

准备数据

数据格式是灵活的,只要它与我们稍后将定义的自定义整理器兼容即可。一种常见的方法是使用对话数据。鉴于数据包含文本和图像,格式需要相应调整。以下是一个涉及文本和图像的对话数据格式示例

images = ["obama.png"]
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Who is this?"},
            {"type": "image"}
        ]
    },
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "Barack Obama"}
        ]
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is he famous for?"}
        ]
    },
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "He is the 44th President of the United States."}
        ]
    }
]

为了说明如何使用 LLaVA 模型处理此数据格式,您可以使用以下代码

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
print(processor.apply_chat_template(messages, tokenize=False))

输出将格式化如下

Who is this? ASSISTANT: Barack Obama USER: What is he famous for? ASSISTANT: He is the 44th President of the United States. 

用于处理多模态数据的自定义整理器

SFTTrainer 的默认行为不同,多模态数据的处理是在数据整理过程中动态完成的。为此,您需要定义一个自定义整理器来处理文本和图像。该整理器必须接受一个示例列表作为输入(有关数据格式的示例,请参见上一节)并返回一批处理过的数据。以下是此类整理器的一个示例

def collate_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
    images = [example["images"][0] for example in examples]

    # Tokenize the texts and process the images
    batch = processor(images=images, text=texts, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    batch["labels"] = labels

    return batch

我们可以通过运行以下代码来验证整理器是否按预期工作

from datasets import load_dataset

dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train")
examples = [dataset[0], dataset[1]]  # Just two examples for the sake of the example
collated_data = collate_fn(examples)
print(collated_data.keys())  # dict_keys(['input_ids', 'attention_mask', 'pixel_values', 'labels'])

训练视觉语言模型

现在我们已经准备好数据并定义了整理器,我们可以继续训练模型了。为了确保数据不被仅作为文本处理,我们需要在 SFTConfig 中设置几个参数,特别是将 `remove_unused_columns` 和 `skip_prepare_dataset` 设置为 `True` 以避免数据集的默认处理。以下是如何设置 `SFTTrainer` 的示例。

training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}

trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_dataset,
    processing_class=processor,
)

有关在 HuggingFaceH4/llava-instruct-mix-vsft 数据集上训练 LLaVa 1.5 的完整示例,请参阅脚本 examples/scripts/sft_vlm.py

SFTTrainer

class trl.SFTTrainer

< >

( model: typing.Union[str, torch.nn.modules.module.Module, transformers.modeling_utils.PreTrainedModel] args: typing.Union[trl.trainer.sft_config.SFTConfig, transformers.training_args.TrainingArguments, NoneType] = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None train_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, NoneType] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None compute_loss_func: typing.Optional[typing.Callable] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], dict]] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) optimizer_cls_and_kwargs: typing.Optional[tuple[type[torch.optim.optimizer.Optimizer], dict[str, typing.Any]]] = None preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None peft_config: typing.Optional[ForwardRef('PeftConfig')] = None formatting_func: typing.Optional[typing.Callable[[dict], str]] = None )

参数

  • model (Union[str, PreTrainedModel]) — 要训练的模型。可以是:

    • 一个字符串,即 huggingface.co 上模型仓库中预训练模型的*模型 ID*,或包含使用 `save_pretrained` 保存的模型权重的*目录*路径,例如 `'./my_model_directory/'`。模型使用 `from_pretrained` 和 `args.model_init_kwargs` 中的关键字参数加载。
    • 一个 PreTrainedModel 对象。仅支持因果语言模型。
  • args (SFTConfig, *可选*, 默认为 None) — 此训练器的配置。如果为 None,则使用默认配置。
  • data_collator (DataCollator, *可选*) — 用于从处理过的 `train_dataset` 或 `eval_dataset` 的元素列表中形成批次的函数。将默认为自定义的 `DataCollatorForLanguageModeling`。
  • train_dataset (DatasetIterableDataset) — 用于训练的数据集。SFT 支持语言建模类型和提示-补全类型。样本的格式可以是:

    • 标准:每个样本包含纯文本。
    • 对话式:每个样本包含结构化消息(例如,角色和内容)。

    训练器还支持已处理(已分词)的数据集,只要它们包含一个 `input_ids` 字段。

  • eval_dataset (DatasetIterableDataset 或 `dict[str, Union[Dataset, IterableDataset]]`) — 用于评估的数据集。它必须满足与 `train_dataset` 相同的要求。
  • processing_class (PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixinProcessorMixin, *可选*, 默认为 None) — 用于处理数据的处理类。如果为 None,则从模型的名称使用 from_pretrained 加载处理类。
  • callbacks (TrainerCallback 列表, *可选*, 默认为 None) — 用于自定义训练循环的回调列表。将添加到详见此处的默认回调列表中。

    如果要删除使用的默认回调之一,请使用 `remove_callback` 方法。

  • optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR], *可选*, 默认为 (None, None)) — 包含要使用的优化器和调度器的元组。将默认为模型上的 `AdamW` 实例和由 `args` 控制的 `get_linear_schedule_with_warmup` 提供的调度器。
  • optimizer_cls_and_kwargs (Tuple[Type[torch.optim.Optimizer], Dict[str, Any]], *可选*, 默认为 None) — 包含优化器类和要使用的关键字参数的元组。覆盖 `args` 中的 `optim` 和 `optim_args`。与 `optimizers` 参数不兼容。

    与 `optimizers` 不同,此参数避免了在初始化 Trainer 之前将模型参数放置在正确设备上的需要。

  • preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor], *可选*, 默认为 None) — 一个函数,用于在每个评估步骤缓存 logits 之前对其进行预处理。必须接受两个张量,logits 和标签,并返回处理后所需的 logits。此函数所做的修改将反映在 `compute_metrics` 接收的预测中。

    请注意,如果数据集没有标签,则标签(第二个参数)将为 `None`。

  • peft_config (~peft.PeftConfig, *可选*, 默认为 None) — 用于包装模型的 PEFT 配置。如果为 `None`,则不包装模型。
  • formatting_func (Optional[Callable]) — 在分词前应用于数据集的格式化函数。显式应用格式化函数会将数据集转换为语言建模类型。

用于监督式微调(SFT)方法的训练器。

此类是 `transformers.Trainer` 类的包装器,并继承其所有属性和方法。

示例

from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")

trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
trainer.train()

train

< >

( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None ignore_keys_for_eval: typing.Optional[list[str]] = None **kwargs )

参数

  • resume_from_checkpoint (str or bool, optional) — 如果是 str,则为先前 Trainer 实例保存的检查点的本地路径。如果是 bool 且等于 True,则加载先前 Trainer 实例保存在 args.output_dir 中的最后一个检查点。如果提供此参数,训练将从加载的模型/优化器/调度器状态恢复。
  • trial (optuna.Trialdict[str, Any], optional) — 用于超参数搜索的试验运行或超参数字典。
  • ignore_keys_for_eval (list[str], optional) — 模型输出(如果为字典)中的一个键列表,在训练期间收集评估预测时应忽略这些键。
  • kwargs (dict[str, Any], optional) — 用于隐藏已弃用参数的附加关键字参数。

主训练入口点。

save_model

< >

( output_dir: typing.Optional[str] = None _internal_call: bool = False )

将保存模型,以便您可以使用 `from_pretrained()` 重新加载它。

仅从主进程保存。

push_to_hub

< >

( commit_message: typing.Optional[str] = 'End of training' blocking: bool = True token: typing.Optional[str] = None revision: typing.Optional[str] = None **kwargs )

参数

  • commit_message (str, optional, 默认为 "End of training") — 推送时要提交的消息。
  • blocking (bool, optional, 默认为 True) — 函数是否仅在 git push 完成后返回。
  • token (str, optional, 默认为 None) — 具有写入权限的令牌,用于覆盖 Trainer 的原始参数。
  • revision (str, optional) — 要提交的 git 修订版本。默认为“main”分支的头部。
  • kwargs (dict[str, Any], optional) — 传递给 ~Trainer.create_model_card 的附加关键字参数。

将 `self.model` 和 `self.processing_class` 上传到 🤗 模型中心的 `self.args.hub_model_id` 存储库。

SFTConfig

class trl.SFTConfig

< >

( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 2e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: typing.Optional[bool] = None fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: bool = True model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None chat_template_path: typing.Optional[str] = None dataset_text_field: str = 'text' dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None dataset_num_proc: typing.Optional[int] = None eos_token: typing.Optional[str] = None pad_token: typing.Optional[str] = None max_length: typing.Optional[int] = 1024 packing: bool = False packing_strategy: str = 'bfd' padding_free: bool = False pad_to_multiple_of: typing.Optional[int] = None eval_packing: typing.Optional[bool] = None completion_only_loss: typing.Optional[bool] = None assistant_only_loss: bool = False activation_offloading: bool = False )

控制模型的参数

  • model_init_kwargs (dict[str, Any]None, optional, 默认为 None) — 当 SFTTrainermodel 参数以字符串形式提供时,用于 from_pretrained 的关键字参数。
  • chat_template_path (strNone, optional, 默认为 None) — 如果指定,则设置模型的聊天模板。这可以是一个分词器(本地目录或 Hugging Face Hub 模型)的路径,也可以是一个 Jinja 模板文件的直接路径。使用 Jinja 文件时,必须确保模板中引用的任何特殊令牌都已添加到分词器中,并相应地调整模型的嵌入层大小。

控制数据预处理的参数

  • dataset_text_field (str, optional, 默认为 "text") — 数据集中包含文本数据的列名。
  • dataset_kwargs (dict[str, Any]None, optional, 默认为 None) — 数据集准备的可选关键字参数字典。唯一支持的键是 skip_prepare_dataset
  • dataset_num_proc (intNone, optional, 默认为 None) — 用于处理数据集的进程数。
  • eos_token (strNone, optional, 默认为 None) — 用于指示一轮对话或序列结束的令牌。如果为 None,则默认为 processing_class.eos_token
  • pad_token (intNone, optional, 默认为 None) — 用于填充的令牌。如果为 None,则默认为 processing_class.pad_token,如果该值也为 None,则回退到 processing_class.eos_token
  • max_length (intNone, optional, 默认为 1024) — 标记化序列的最大长度。超过 max_length 的序列将从右侧截断。如果为 None,则不应用截断。启用打包时,此值设置序列长度。
  • packing (bool, optional, 默认为 False) — 是否将多个序列分组到固定长度的块中,以提高计算效率并减少填充。使用 max_length 定义序列长度。
  • packing_strategy (str, optional, 默认为 "bfd") — 打包序列的策略。可以是 "bfd"(最佳拟合递减,默认值)或 "wrapped"
  • padding_free (bool, optional, 默认为 False) — 是否通过将批次中的所有序列展平为单个连续序列来执行无填充的前向传播。这通过消除填充开销来减少内存使用。目前,这仅在 FlashAttention 2 或 3 中受支持,因为它们可以高效处理展平的批次结构。当使用 "bfd" 策略启用打包时,无论此参数的值如何,都会启用无填充。
  • pad_to_multiple_of (intNone, optional, 默认为 None) — 如果设置,序列将被填充到该值的倍数。
  • eval_packing (boolNone, optional, 默认为 None) — 是否打包评估数据集。如果为 None,则使用与 packing 相同的值。

控制训练的参数

  • completion_only_loss (boolNone, optional, 默认为 None) — 是否仅对序列的补全部分计算损失。如果设置为 True,则仅对补全部分计算损失,这仅支持提示-补全数据集。如果为 False,则对整个序列计算损失。如果为 None(默认),行为取决于数据集:对于提示-补全数据集,对补全部分计算损失;对于语言建模数据集,对整个序列计算损失。
  • assistant_only_loss (bool, optional, 默认为 False) — 是否仅对序列的助手部分计算损失。如果设置为 True,则仅对助手响应计算损失,这仅支持对话数据集。如果为 False,则对整个序列计算损失。
  • activation_offloading (bool, optional, 默认为 False) — 是否将激活卸载到 CPU。

用于 SFTTrainer 的配置类。

此类仅包含特定于 SFT 训练的参数。有关训练参数的完整列表,请参阅 TrainingArguments 文档。请注意,此类中的默认值可能与 TrainingArguments 中的不同。

使用 HfArgumentParser,我们可以将此类别转换为可在命令行上指定的 argparse 参数。

< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.