TRL 文档

迭代训练器 (Iterative Trainer)

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

迭代训练器 (Iterative Trainer)

迭代式微调是一种训练方法,它允许在优化步骤之间执行自定义操作(例如生成和过滤)。在 TRL 中,我们提供了一个易于使用的 API,只需几行代码即可迭代地微调您的模型。

快速入门

要快速开始,您可以将模型标识符或预实例化的模型传递给训练器

from trl import IterativeSFTConfig, IterativeSFTTrainer

# Using a model identifier
trainer = IterativeSFTTrainer(
    "facebook/opt-350m",
    args=IterativeSFTConfig(
        max_length=512,
        output_dir="./output",
    ),
)

# Or using a pre-instantiated model
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

trainer = IterativeSFTTrainer(
    model,
    args=IterativeSFTConfig(
        max_length=512,
        output_dir="./output",
    ),
    processing_class=tokenizer,
)

用法

IterativeSFTTrainer 支持两种向 step 函数提供输入数据的方式

使用张量列表作为输入:

inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
}

trainer.step(**inputs)

使用字符串列表作为输入:

inputs = {
    "texts": texts,
    "texts_labels": texts_labels,  # Optional, defaults to texts
}

trainer.step(**inputs)

对于因果语言模型,标签将自动从 input_idstexts 创建。使用序列到序列模型时,您必须提供自己的标签或 text_labels

配置

IterativeSFTConfig 类提供了几个参数来定制训练

from trl import IterativeSFTConfig

config = IterativeSFTConfig(
    # Model initialization parameters
    model_init_kwargs={"torch_dtype": "bfloat16"},

    # Data preprocessing parameters
    max_length=512,
    truncation_mode="keep_end",

    # Training parameters
    output_dir="./output",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    max_steps=1000,
    save_steps=100,
    optim="adamw_torch",
    report_to="wandb",
)

模型初始化

您可以通过向 model_init_kwargs 传递关键字参数来控制模型的初始化方式

config = IterativeSFTConfig(
    model_init_kwargs={
        "torch_dtype": "bfloat16",
        "device_map": "auto",
        "trust_remote_code": True,
    }
)

数据预处理

训练器支持两种截断模式

  • keep_end:从序列的开头截断
  • keep_start:从序列的末尾截断
config = IterativeSFTConfig(
    max_length=512,
    truncation_mode="keep_end",  # or "keep_start"
)

训练优化

您可以优化 CUDA 缓存使用,以实现更节省内存的训练

config = IterativeSFTConfig(
    optimize_device_cache=True,
)

IterativeSFTTrainer

class trl.IterativeSFTTrainer

< >

( model: typing.Union[str, transformers.modeling_utils.PreTrainedModel] args: typing.Union[trl.trainer.iterative_sft_config.IterativeSFTConfig, transformers.training_args.TrainingArguments, NoneType] = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None optimizers: tuple = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalLoopOutput], dict]] = None )

参数

  • model (Union[str, PreTrainedModel]) — 待训练的模型。可以是:

    • 一个字符串,表示 huggingface.co 上模型仓库中预训练模型的 model id,或包含使用 save_pretrained 保存的模型权重的 目录 路径,例如 './my_model_directory/'。模型使用 from_pretrainedargs.model_init_kwargs 中的关键字参数加载。
    • 一个 PreTrainedModel 对象。仅支持因果语言模型。
  • args (IterativeSFTConfig, 可选, 默认为 None) — 此训练器的配置。如果为 None,则使用默认配置。
  • data_collator (DataCollator, 可选) — 用于从处理过的 train_dataseteval_dataset 的元素列表中形成批次的函数。如果未提供 processing_class,将默认为 default_data_collator;如果 processing_class 是特征提取器或分词器,则默认为 DataCollatorWithPadding 的实例。
  • eval_dataset (datasets.Dataset) — 用于评估的数据集。
  • processing_class (PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixinProcessorMixin, 可选, 默认为 None) — 用于处理数据的处理类。如果为 None,则处理类将从模型的名称使用 from_pretrained 加载。
  • optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]) — 用于训练的优化器和调度器。
  • preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) — 在计算指标之前用于预处理 logits 的函数。
  • compute_metrics (Callable[[EvalPrediction], dict], 可选) — 用于计算指标的函数。必须接受一个 EvalPrediction 并返回一个从字符串到指标值的字典。

IterativeSFTTrainer 可用于通过需要在优化之间执行某些步骤的方法来微调模型。

训练 (train)

< >

( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None ignore_keys_for_eval: typing.Optional[list[str]] = None **kwargs )

参数

  • resume_from_checkpoint (str or bool, 可选) — 如果是 str,则为先前 Trainer 实例保存的检查点的本地路径。如果是 bool 且等于 True,则加载先前 Trainer 实例在 *args.output_dir* 中保存的最后一个检查点。如果存在,训练将从此加载的模型/优化器/调度器状态恢复。
  • trial (optuna.Trialdict[str, Any], 可选) — 用于超参数搜索的试验运行或超参数字典。
  • ignore_keys_for_eval (list[str], 可选) — 你的模型输出中(如果它是一个字典)的键列表,在训练期间收集评估预测时应忽略这些键。
  • kwargs (dict[str, Any], 可选) — 用于隐藏已弃用参数的额外关键字参数

主训练入口点。

保存模型 (save_model)

< >

( output_dir: typing.Optional[str] = None _internal_call: bool = False )

将保存模型,以便您可以使用 `from_pretrained()` 重新加载它。

仅从主进程保存。

推送到 Hub (push_to_hub)

< >

( commit_message: typing.Optional[str] = 'End of training' blocking: bool = True token: typing.Optional[str] = None revision: typing.Optional[str] = None **kwargs )

参数

  • commit_message (str, 可选, 默认为 "End of training") — 推送时提交的消息。
  • blocking (bool, 可选, 默认为 True) — 函数是否应在 git push 完成后才返回。
  • token (str, 可选, 默认为 None) — 具有写入权限的令牌,用于覆盖 Trainer 的原始参数。
  • revision (str, 可选) — 要提交的 git 修订版本。默认为“main”分支的头部。
  • kwargs (dict[str, Any], 可选) — 传递给 ~Trainer.create_model_card 的额外关键字参数。

将 `self.model` 和 `self.processing_class` 上传到 🤗 模型中心的 `self.args.hub_model_id` 存储库。

IterativeSFTConfig

class trl.IterativeSFTConfig

< >

( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: typing.Optional[bool] = None fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = True model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None max_length: typing.Optional[int] = None truncation_mode: str = 'keep_end' optimize_device_cache: bool = False )

控制模型的参数

  • model_init_kwargs (dict[str, Any]None, 可选, 默认为 None) — 用于 from_pretrained 的关键字参数,当 IterativeSFTTrainermodel 参数以字符串形式提供时使用。

控制数据预处理的参数

  • max_length (intNone, 可选, 默认为 None) — 分词后序列的最大长度。超过 max_length 的序列将被截断。
  • truncation_mode (str, 可选, 默认为 "keep_end") — 要使用的截断模式,可以是 "keep_end""keep_start"
  • optimize_device_cache (bool, 可选, 默认为 False) — 是否优化加速器缓存以实现内存效率稍高的训练。

IterativeSFTTrainer 的配置类。

此类仅包含特定于迭代式 SFT 训练的参数。有关训练参数的完整列表,请参阅 TrainingArguments 文档。请注意,此类中的默认值可能与 TrainingArguments 中的默认值不同。

使用 HfArgumentParser,我们可以将此类别转换为可在命令行上指定的 argparse 参数。

< > 在 GitHub 上更新