text-generation-inference 文档

训练 Medusa

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

训练 Medusa

本教程将向您展示如何在您选择的数据集上训练 Medusa 模型。请查看推测文档,以获取有关 Medusa 工作原理和一般推测的更多信息。

训练 Medusa 模型有什么好处?

训练 Medusa 头可以大大提高生成速度。Medusa 向 LLM 添加额外的“头”,以同时预测多个未来 token。当使用 Medusa 增强模型时,原始模型保持不变,并且仅在训练期间微调新的头。

最重要的事情之一是拥有良好的数据集(具有与生产中使用的数据相似的数据),因为当生成是领域内时,Medusa 的命中率要高得多。

如果您在与生产中使用的数据集非常不同的数据集上训练 Medusa,则该模型将无法准确预测未来的 token,因此加速将是最小的或不存在的。

自蒸馏(生成用于训练的数据)

有很多方法可以准备用于训练的数据,但最简单有效的方法之一是“自蒸馏”数据。这意味着您可以使用相同的模型来生成将用于训练模型的数据。

本质上,您可以使用与生产中使用的输入类似的输入来提示模型,模型将生成输出。

我们将使用此输出帮助训练 medusa 头,以预测序列中的 n+1n+2n+3 等 token。

训练

Medusa 的原始实现在 https://github.com/FasterDecoding/Medusa 上可用,我们将遵循与原始存储库中描述的非常相似的过程来训练模型。

开始入门

有两种方法可以训练模型

  • torchrun,它是 torch.distributed.launch 的包装器
  • axlotl 的一个分支版本,支持 Medusa

在本教程中,我们将使用 torchrun 来训练模型,因为这是训练模型最直接的方法,但如果您愿意,可以按照类似的步骤使用 axlotl 来训练模型。

使用 torchrun 进行训练

mkdir medusa-training
cd medusa-training

pyenv install 3.10
pyenv local 3.10

uv venv -p 3.10
source .venv/bin/activate

现在让我们克隆原始的 Medusa 存储库并安装库。

git clone https://github.com/FasterDecoding/Medusa.git
cd Medusa
pip install -e .

接下来,我们需要一些数据进行训练,我们可以使用 Hugging Face Hub 上提供的 ShareGPT_Vicuna_unfiltered 数据集。

apt install git-lfs
git lfs install
git clone https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered

目前我们的目录结构如下所示

.
├── assets
├── CITATION.cff
├── create_data.py
├── data_generation
├── deepspeed.json
├── last_run_prepared
├── LICENSE
├── llm_judge
├── medusa
├── medusa_llm.egg-info
├── mistral.json
├── notebooks
├── pyproject.toml
├── README.md
├── ROADMAP.md
├── scripts
├── ShareGPT_Vicuna_unfiltered
│   ├── README.md
│   ├── ShareGPT_2023.05.04v0_Wasteland_Edition.json
│   └── ShareGPT_V4.3_unfiltered_cleaned_split.json
├── simple_gradio_interface.py
├── tiny-llama.json
└── vicuna_7b_qlora_stage1

开始训练

现在让我们生成数据并开始训练模型。此过程将花费一些时间,因为我们正在从模型生成数据。

首先,请确保您有一个 TGI 实例正在运行,其中包含您要用于自蒸馏的模型。

model=HuggingFaceH4/zephyr-7b-beta
volume=/home/ubuntu/.cache/huggingface/hub/

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model

现在我们可以使用 create_data.py 脚本生成数据。

python create_data.py \
    --input-filename ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json \
    --output-filename zephyr_self_distill.json

此时,我们的终端应如下所示

注意:在上面的屏幕截图中,我们仅使用了数据集中的前 500 个示例来加快处理速度,您应该拥有更大的数据集用于训练。

现在我们终于可以开始有趣的部分并开始训练模型了!

使用 torchrun,我们可以使用 zephyr_self_distill.json 配置文件轻松启动 medusa 训练脚本。

注意:如果您刚刚进行了自蒸馏,则可能仍在运行模型,请务必在开始训练之前停止它,以便允许所有资源用于训练。

WANDB_MODE=offline torchrun --nproc_per_node=4 medusa/train/train_legacy.py \
    --model_name_or_path HuggingFaceH4/zephyr-7b-beta \
    --data_path zephyr_self_distill.json \
    --bf16 True \
    --output_dir zephyr_out \
    --num_train_epochs 5 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "no" \
    --save_strategy "no" \
    --learning_rate 1e-3 \
    --weight_decay 0.0 \
    --warmup_ratio 0.1 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --lazy_preprocess True \
    --medusa_num_heads 3 \
    --medusa_num_layers 1 \
    --deepspeed deepspeed.json

如果成功,您应该看到与下面类似的输出

wandb: Run history:
wandb:                    train/epoch ▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
wandb:              train/global_step ▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
wandb:            train/learning_rate ▅███▇▇▆▅▅▄▃▂▂▁▁▁
wandb:                     train/loss ██▆▄▄▃▃▂▂▃▁▁▂▁▁▁
wandb:             train/medusa0_loss ▆▆▇▆▆▅▄▅▃▃▃▃▂▂▂▂▂▃▂▂▂▁▁▁▂▁▁▁▁▁█▁▁▁▂▁▁▁▁▁
wandb:             train/medusa0_top1 ▁▁▁▁▁▁▁▁▃▂▃▃▄▄▄▃▄▃▄▄▅▅▆▅▆▆▇▅▇▇▄▇█▇▅▇█▆▇▇
wandb:             train/medusa1_loss ▇▇█▇▇▆▅▅▃▄▃▃▃▃▃▃▃▃▃▃▂▁▂▂▂▁▁▂▁▁▇▁▁▁▂▁▁▁▁▁
wandb:             train/medusa1_top1 ▁▁▁▁▁▁▁▁▃▂▃▃▃▄▄▃▃▂▃▃▅▅▆▄█▆▇▅▇▇▅█▇▇▅▇█▆▆▇
wandb:             train/medusa2_loss ▃▃▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁█▁▁▁▂▁▁▁▁▁
wandb:             train/medusa2_top1 ▁▁▁▂▁▁▁▁▂▂▃▃▃▄▄▃▃▂▃▃▅▆▅▄█▆▆▅▆▆▄█▇▇▄▇█▆▆▇
wandb:               train/total_flos ▁
wandb:               train/train_loss ▁
wandb:            train/train_runtime ▁
wandb: train/train_samples_per_second ▁
wandb:   train/train_steps_per_second ▁
wandb:
wandb: Run summary:
wandb:                    train/epoch 2.0
wandb:              train/global_step 16
wandb:            train/learning_rate 0.0
wandb:                     train/loss 14.8906
wandb:             train/medusa0_loss 4.25
wandb:             train/medusa0_top1 0.28809
wandb:             train/medusa1_loss 4.8125
wandb:             train/medusa1_top1 0.22727
wandb:             train/medusa2_loss 5.5
wandb:             train/medusa2_top1 0.17293
wandb:               train/total_flos 0.0
wandb:               train/train_loss 23.98242
wandb:            train/train_runtime 396.9266
wandb: train/train_samples_per_second 2.519
wandb:   train/train_steps_per_second 0.04

最后但最重要的是,不要忘记将此模型推送到 Hugging Face Hub,以便您可以在项目中使用它。

python -m medusa.hf_utils \
    --folder zephyr_out_medusa_mlp_zephyr-7b-beta_medusa_3_lr_0.001_layers_1 \
    --repo drbh/zephyr_medusa_demo

哇,我们已成功训练了一个 Medusa 模型并将其推送到了 Hugging Face Hub!🎉

< > 在 GitHub 上更新