TRL 文档

通用知识蒸馏训练器

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

通用知识蒸馏训练器

概述

通用知识蒸馏 (GKD) 在 语言模型的在线蒸馏:从自我生成的错误中学习 中被提出,作者为 Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, 和 Olivier Bachem。

该论文的摘要如下:

知识蒸馏 (KD) 被广泛用于压缩教师模型,以降低其推理成本和内存占用,通过训练一个更小的学生模型。然而,当前用于自回归序列模型的 KD 方法存在训练期间看到的输出序列与学生在推理期间生成的输出序列之间的分布不匹配问题。为了解决这个问题,我们引入了通用知识蒸馏 (GKD)。GKD 不是仅仅依赖于一组固定的输出序列,而是通过利用教师对这些序列的反馈,在学生自我生成的输出序列上训练学生。与监督 KD 方法不同,GKD 还提供了在学生和教师之间采用替代损失函数的灵活性,当学生缺乏模仿教师分布的表达能力时,这可能很有用。此外,GKD 促进了蒸馏与 RL 微调 (RLHF) 的无缝集成。我们证明了 GKD 在摘要、翻译和算术推理任务上蒸馏自回归语言模型以及用于指令调优的任务无关蒸馏的有效性。

GKD 的关键方面是:

  1. 它通过在学生模型自我生成的输出序列上训练学生模型,解决了自回归序列模型中训练-推理分布不匹配的问题。
  2. GKD 允许通过广义 Jensen-Shannon 散度 (JSD) 在学生模型和教师模型之间灵活选择不同的散度度量,当学生缺乏完全模仿教师的能力时,这可能很有用。

这种后训练方法由 Kashif RasulLewis Tunstall 贡献。

使用技巧

GKDTrainerSFTTrainer 类的包装器,它接受教师模型参数。它需要通过 GKDConfig 设置三个参数,即:

  • lmbda:控制学生数据比例,即在线学生生成输出的比例。当 lmbda=0.0 时,损失简化为监督 JSD,其中学生使用教师的 token 级别概率进行训练。当 lmbda=1.0 时,损失简化为在线 JSD,其中学生生成输出序列,并从教师处获得关于这些序列的 token 特定反馈。对于 [0, 1] 之间的值,它根据每个批次的 lmbda 值在两者之间随机选择。
  • seq_kd:控制是否执行序列级别 KD(可以看作是对教师生成输出的监督 FT)。当 seq_kd=Truelmbda=0.0 时,损失简化为监督 JSD,其中教师生成输出序列,学生从教师处获得关于这些序列的 token 特定反馈。
  • beta:控制广义 Jensen-Shannon 散度中的插值。当 beta=0.0 时,损失近似于前向 KL 散度,而当 beta=1.0 时,损失近似于反向 KL 散度。对于 [0, 1] 之间的值,它在两者之间进行插值。

作者发现在线数据(高 lmbda)表现更好,最佳 beta 值因任务和评估方法而异。

在训练 Gemma 模型时,请确保 attn_implementation="flash_attention_2"。否则,由于此架构采用的 软上限技术,您将在 logits 中遇到 NaNs。

基本 API 如下:

from datasets import Dataset
from trl import GKDConfig, GKDTrainer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)

NUM_DUMMY_SAMPLES = 100

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The model to optimise
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The teacher model to calculate the KL divergence against
teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

train_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "Hi, how are you?"},
                {"role": "assistant", "content": "I'm great thanks"},
            ]
        ]
        * NUM_DUMMY_SAMPLES
    }
)
eval_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "What colour is the sky?"},
                {"role": "assistant", "content": "The sky is blue"},
            ]
        ]
        * NUM_DUMMY_SAMPLES
    }
)

training_args = GKDConfig(output_dir="gkd-model", per_device_train_batch_size=1)
trainer = GKDTrainer(
    model=model,
    teacher_model=teacher_model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

预期数据集类型

数据集应格式化为“messages”列表,其中每个消息都是字典列表,包含以下键:

  • rolesystemassistantuser 之一
  • content:消息内容

GKDTrainer

class trl.GKDTrainer

< >

( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, str, NoneType] = None teacher_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, str] = None args: typing.Optional[trl.trainer.gkd_config.GKDConfig] = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None train_dataset: typing.Optional[datasets.arrow_dataset.Dataset] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], dict]] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None peft_config: typing.Optional[ForwardRef('PeftConfig')] = None formatting_func: typing.Optional[typing.Callable] = None )

generalized_jsd_loss

< >

( student_logits teacher_logits labels = None beta = 0.5 temperature = 1.0 reduction = 'batchmean' ) loss

参数

  • student_logits — 形状为 (batch_size, sequence_length, vocab_size) 的张量
  • teacher_logits — 形状为 (batch_size, sequence_length, vocab_size) 的张量
  • labels — 形状为 (batch_size, sequence_length) 的张量,其中填充 token 为 -100,在计算损失时忽略
  • beta — 介于 0 和 1 之间的插值系数(默认值:0.5)
  • temperature — Softmax 温度(默认值:1.0)
  • reduction — 指定应用于输出的归约方式(默认值:‘batchmean’)

返回值

loss

包含广义 JSD 损失的标量张量

使用 F.kl_div 计算知识蒸馏的广义 Jensen-Shannon 散度损失。有关定义,请参见 https://huggingface.ac.cn/papers/2306.13649 中的公式 (1) 。

training_step

< >

( model: Module inputs: dict num_items_in_batch: typing.Optional[int] = None )

执行通用知识蒸馏 (GKD) 模型的训练步骤。

此方法实现了 GKD 论文中描述的在线学习方法。以 self.lmbda 的概率,它使用学生模型生成新的响应,这些响应随后用于训练,而不是原始输入。

GKDConfig

class trl.GKDConfig

< >

( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 2e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict, str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: typing.Optional[str] = 'passive' log_level_replica: typing.Optional[str] = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict, str, NoneType] = None tp_size: typing.Optional[int] = 0 fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict, str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_idpush_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: typing.Optional[int] = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metricseval_on_startuse_liger_kerneleval_use_gather_objectaverage_tokens_across_devicesmodel_init_kwargsdataset_text_fielddataset_kwargsdataset_num_procpad_tokenmax_lengthpackingpadding_freeeval_packingdataset_batch_sizenum_of_sequenceschars_per_tokenmax_seq_lengthuse_ligertemperaturelmbdabetamax_new_tokensteacher_model_name_or_pathteacher_model_init_kwargsdisable_dropoutseq_kd)

参数

  • temperature (float, optional, defaults to 0.9) — 采样温度。温度越高,补全内容就越随机。
  • lmbda (float, optional, defaults to 0.5) — Lambda 参数,用于控制学生数据比例(即,策略内学生生成的输出的比例)。
  • beta (float, optional, defaults to 0.5) — 广义 Jensen-Shannon 散度损失在 0.01.0 之间的插值系数。当 beta 为 0.0 时,损失为 KL 散度。当 beta 为 1.0 时,损失为逆 KL 散度。
  • max_new_tokens (int, optional, defaults to 128) — 每次补全生成的最大 token 数。
  • teacher_model_name_or_path (str or None, optional, defaults to None) — 教师模型的模型名称或路径。如果为 None,则教师模型将与正在训练的模型相同。
  • teacher_model_init_kwargs (dict[str, Any]] or None, optional, defaults to None) — 从字符串实例化教师模型时,传递给 AutoModelForCausalLM.from_pretrained 的关键字参数。
  • disable_dropout (bool, optional, defaults to True) — 是否禁用模型中的 dropout。
  • seq_kd (bool, optional, defaults to False) — Seq_kd 参数,用于控制是否执行序列级 KD(可以看作是对教师生成的输出进行监督式微调)。

用于 GKDTrainer 的配置类。

< > GitHub 上更新