广义知识蒸馏训练器
概述
广义知识蒸馏 (GKD) 在 Rishabh Agarwal、Nino Vieillard、Yongchao Zhou、Piotr Stanczyk、Sabela Ramos、Matthieu Geist 和 Olivier Bachem 的论文 语言模型的在线策略蒸馏:从自我生成的错误中学习 中提出。
论文摘要如下:
知识蒸馏 (KD) 广泛用于压缩教师模型以降低其推理成本和内存占用,方法是训练一个较小的学生模型。然而,当前用于自回归序列模型的 KD 方法存在训练期间看到的输出序列与学生在推理期间生成的输出序列之间的分布不匹配问题。为了解决这个问题,我们引入了广义知识蒸馏 (GKD)。GKD 不仅仅依赖于一组固定的输出序列,而是通过利用教师对这些序列的反馈,在学生自身生成的输出序列上训练学生。与监督 KD 方法不同,GKD 还提供了在学生和教师之间使用替代损失函数的灵活性,这在学生缺乏模仿教师分布的表达能力时很有用。此外,GKD 促进了蒸馏与强化学习微调 (RLHF) 的无缝集成。我们证明了 GKD 在摘要、翻译和算术推理任务上蒸馏自回归语言模型的有效性,以及用于指令微调的与任务无关的蒸馏。
GKD 的关键方面是
- 它通过在学生模型自身生成的输出序列上训练学生模型来解决自回归序列模型中训练-推理分布不匹配的问题。
- GKD 允许通过广义 Jensen-Shannon 散度 (JSD) 选择学生和教师模型之间不同的散度度量,这在学生缺乏完全模仿教师的能力时很有用。
此训练后方法由 Kashif Rasul 和 Lewis Tunstall 贡献。
使用技巧
GKD Trainer 是一个围绕 SFTTrainer 类的包装器,它接收一个教师模型参数。它需要通过 GKDConfig 设置两个参数,即
lmbda
:控制学生数据比例,即基于策略的学生生成输出的比例。当lmbda=0.0
时,损失函数简化为监督 JSD,其中学生使用教师的 token 级概率进行训练。当lmbda=1.0
时,损失函数简化为基于策略的 JSD,其中学生生成输出序列,并从教师处获得这些序列的 token 特定反馈。对于 [0, 1] 之间的值,它根据每个批次的lmbda
值在两者之间随机选择。beta
:控制广义 Jensen-Shannon 散度中的插值。当beta=0.0
时,损失函数近似于前向 KL 散度,而当beta=1.0
时,损失函数近似于后向 KL 散度。对于 [0, 1] 之间的值,它在这两者之间进行插值。
作者发现,基于策略的数据 (高 lmbda
) 表现更好,并且最佳 beta
值取决于任务和评估方法。
基本 API 如下所示
from datasets import Dataset
from trl import GKDConfig, GKDTrainer
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
)
NUM_DUMMY_SAMPLES = 100
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The model to optimise
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The teacher model to calculate the KL divergence against
teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
train_dataset = Dataset.from_dict(
{
"messages": [
[
{"role": "user", "content": "Hi, how are you?"},
{"role": "assistant", "content": "I'm great thanks"},
]
]
* NUM_DUMMY_SAMPLES
}
)
eval_dataset = Dataset.from_dict(
{
"messages": [
[
{"role": "user", "content": "What colour is the sky?"},
{"role": "assistant", "content": "The sky is blue"},
]
]
* NUM_DUMMY_SAMPLES
}
)
args = GKDConfig(output_dir="gkd-model", per_device_train_batch_size=1)
trainer = GKDTrainer(
model=model,
teacher_model=teacher_model,
args=args,
tokenizer=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()
预期数据集格式
数据集应格式化为“消息”列表,其中每条消息都是一个字典列表,包含以下键:
role
:可以是system
、assistant
或user
content
:消息内容
GKDTrainer
generalized_jsd_loss
< 源代码 >( student_logits teacher_logits labels = None beta = 0.5 temperature = 1.0 reduction = 'batchmean' ) → 损失
返回值
损失
包含广义 JSD 损失的标量张量
使用 F.kl_div 计算用于知识蒸馏的广义 Jensen-Shannon 散度损失。有关定义,请参阅 https://arxiv.org/abs/2306.13649 的公式 (1)。
执行广义知识蒸馏 (GKD) 模型的训练步骤。
此方法实现了 GKD 论文中描述的基于策略的学习方法。以 self.lmbda
的概率,它使用学生模型生成新的响应,然后将其用于训练,而不是原始输入。
GKDConfig
类 trl.GKDConfig
< 源代码 >( output_dir: str overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: Union = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: Optional = None per_gpu_eval_batch_size: Optional = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: Optional = None eval_delay: Optional = 0 torch_empty_cache_steps: Optional = None learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: Union = 'linear' lr_scheduler_kwargs: Union = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: Optional = 'passive' log_level_replica: Optional = 'warning' log_on_each_node: bool = True logging_dir: Optional = None logging_strategy: Union = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: Union = 'steps' save_steps: float = 500 save_total_limit: Optional = None save_safetensors: Optional = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: Optional = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: Optional = None local_rank: int = -1 ddp_backend: Optional = None tpu_num_cores: Optional = None tpu_metrics_debug: bool = False debug: Union = '' dataloader_drop_last: bool = False eval_steps: Optional = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: Optional = None past_index: int = -1 run_name: Optional = None disable_tqdm: Optional = None remove_unused_columns: Optional = True label_names: Optional = None load_best_model_at_end: Optional = False metric_for_best_model: Optional = None greater_is_better: Optional = None ignore_data_skip: bool = False fsdp: Union = '' fsdp_min_num_params: int = 0 fsdp_config: Union = None fsdp_transformer_layer_cls_to_wrap: Optional = None accelerator_config: Union = None deepspeed: Union = None label_smoothing_factor: float = 0.0 optim: Union = 'adamw_torch' optim_args: Optional = None adafactor: bool = False group_by_length: bool = False length_column_name: Optional = 'length' report_to: Union = None ddp_find_unused_parameters: Optional = None ddp_bucket_cap_mb: Optional = None ddp_broadcast_buffers: Optional = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: Optional = None hub_model_id: Optional = None hub_strategy: Union = 'every_save' hub_token: Optional = None hub_private_repo: bool = False hub_always_push: bool = False gradient_checkpointing: bool = False gradient_checkpointing_kwargs: Union = None include_inputs_for_metrics: bool = False include_for_metrics: List = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' evaluation_strategy: Union = None push_to_hub_model_id: Optional = None push_to_hub_organization: Optional = None push_to_hub_token: Optional = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: Optional = None ray_scope: Optional = 'last' ddp_timeout: Optional = 1800 torch_compile: bool = False torch_compile_backend: Optional = None torch_compile_mode: Optional = None dispatch_batches: Optional = None split_batches: Optional = None include_tokens_per_second: Optional = False include_num_input_tokens_seen: Optional = False neftune_noise_alpha: Optional = None optim_target_modules: Union = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: Optional = False eval_use_gather_object: Optional = False dataset_text_field: Optional = None packing: bool = False max_seq_length: Optional = None dataset_num_proc: Optional = None dataset_batch_size: int = 1000 model_init_kwargs: Optional = None dataset_kwargs: Optional = None eval_packing: Optional = None num_of_sequences: int = 1024 chars_per_token: float = 3.6 use_liger: bool = False temperature: float = 0.9 lmbda: float = 0.5 beta: float = 0.5 max_new_tokens: int = 128 teacher_model_name_or_path: Optional = None teacher_model_init_kwargs: Optional = None disable_dropout: bool = True )
参数
- temperature (
float
, 可选, 默认为0.9
) — 采样温度。温度越高,完成结果越随机。 - lmbda (
float
,可选,默认为0.5
) — 控制学生数据比例(即策略内学生生成输出的比例)的 Lambda 参数。 - beta (
float
,可选,默认为0.5
) — 广义 Jensen-Shannon 散度损失的插值系数,介于0.0
和1.0
之间。当 beta 为0.0
时,损失为 KL 散度。当 beta 为1.0
时,损失为逆 KL 散度。 - max_new_tokens (
int
,可选,默认为128
) — 每次完成生成的最大标记数。 - teacher_model_name_or_path (
Optional[str]
,可选,默认为None
) — 教师模型的模型名称或路径。如果为None
,则教师模型将与正在训练的模型相同。 - teacher_model_init_kwargs (
Optional[Dict[str, Any]]
,可选,默认为None
) — 在使用字符串实例化教师模型时,传递给AutoModelForCausalLM.from_pretrained
的关键字参数。 - disable_dropout (
bool
,可选,默认为True
) — 是否禁用model
中的 dropout。
GKDTrainer 的配置类。