训练 Medusa 模型
本教程将向您展示如何在您选择的数据集上训练 Medusa 模型。请查看 推测文档,以获取有关 Medusa 工作原理和一般推测的更多信息。
训练 Medusa 模型的优势是什么?
训练 Medusa 头部可以大大提高生成速度。Medusa 向 LLM 添加额外的“头部”,以同时预测多个未来标记。当使用 Medusa 增强模型时,原始模型保持不变,只有新头部在训练期间被微调。
最重要的是拥有一个良好的数据集(与生产中将使用的数据类似),因为当生成在域内时,Medusa 的命中率要高得多。
如果您在与生产中将使用的数据集大不相同的数据集上训练 Medusa,那么该模型将无法准确地预测未来标记,因此加速将最小或不存在。
自蒸馏(生成训练数据)
有很多方法可以准备训练数据,但最简单、最有效的方法之一是“自蒸馏”数据。这意味着您可以使用相同的模型生成将用于训练模型的数据。
从本质上讲,您使用类似于生产中将使用的输入提示模型,然后模型将生成输出。
我们将使用此输出帮助训练 Medusa 头部,以预测序列中的n+1
、n+2
、n+3
等标记。
训练
Medusa 的原始实现可在 https://github.com/FasterDecoding/Medusa 找到,我们将按照原始仓库中描述的非常类似的过程来训练模型。
入门
有两种方法可以训练模型
torchrun
,它是torch.distributed.launch
的包装器- 一个支持 Medusa 的
axlotl
的分叉版本
在本教程中,我们将使用 torchrun
来训练模型,因为它是最直接的训练模型方式,但如果您愿意,也可以使用 axlotl
遵循类似步骤来训练模型。
使用 torchrun 进行训练
mkdir medusa-training
cd medusa-training
pyenv install 3.10
pyenv local 3.10
uv venv -p 3.10
source .venv/bin/activate
现在让我们克隆原始的 Medusa
仓库并安装库。
git clone https://github.com/FasterDecoding/Medusa.git
cd Medusa
pip install -e .
接下来我们需要一些数据来训练,我们可以使用 Hugging Face Hub 上提供的 ShareGPT_Vicuna_unfiltered
数据集。
apt install git-lfs
git lfs install
git clone https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered
目前我们的目录结构如下所示
. ├── assets ├── CITATION.cff ├── create_data.py ├── data_generation ├── deepspeed.json ├── last_run_prepared ├── LICENSE ├── llm_judge ├── medusa ├── medusa_llm.egg-info ├── mistral.json ├── notebooks ├── pyproject.toml ├── README.md ├── ROADMAP.md ├── scripts ├── ShareGPT_Vicuna_unfiltered │ ├── README.md │ ├── ShareGPT_2023.05.04v0_Wasteland_Edition.json │ └── ShareGPT_V4.3_unfiltered_cleaned_split.json ├── simple_gradio_interface.py ├── tiny-llama.json └── vicuna_7b_qlora_stage1
开始训练
现在让我们生成数据并开始训练模型。这个过程需要一段时间,因为我们是从模型中生成数据。
首先确保你有一个运行 TGI 的实例,其中包含你要用于自蒸馏的模型。
model=HuggingFaceH4/zephyr-7b-beta
volume=/home/ubuntu/.cache/huggingface/hub/
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model
现在我们可以使用 create_data.py
脚本来生成数据。
python create_data.py \ --input-filename ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json \ --output-filename zephyr_self_distill.json
此时我们的终端应该看起来像这样
注意:在上面的屏幕截图中,我们只使用了数据集中的前 500 个示例来加快进程,你应该拥有一个更大的数据集进行训练。
现在我们可以开始训练模型,这是最有趣的部分了!
使用 torchrun
,我们可以轻松地使用 zephyr_self_distill.json
配置文件启动 medusa
训练脚本。
注意:如果你刚刚完成了自蒸馏,你的模型可能还在运行,请确保在开始训练之前停止它,以便所有资源都可用于训练。
WANDB_MODE=offline torchrun --nproc_per_node=4 medusa/train/train_legacy.py \
--model_name_or_path HuggingFaceH4/zephyr-7b-beta \
--data_path zephyr_self_distill.json \
--bf16 True \
--output_dir zephyr_out \
--num_train_epochs 5 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 4 \
--evaluation_strategy "no" \
--save_strategy "no" \
--learning_rate 1e-3 \
--weight_decay 0.0 \
--warmup_ratio 0.1 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--lazy_preprocess True \
--medusa_num_heads 3 \
--medusa_num_layers 1 \
--deepspeed deepspeed.json
如果成功,你应该看到类似于以下输出的内容
wandb: Run history:
wandb: train/epoch ▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
wandb: train/global_step ▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
wandb: train/learning_rate ▅███▇▇▆▅▅▄▃▂▂▁▁▁
wandb: train/loss ██▆▄▄▃▃▂▂▃▁▁▂▁▁▁
wandb: train/medusa0_loss ▆▆▇▆▆▅▄▅▃▃▃▃▂▂▂▂▂▃▂▂▂▁▁▁▂▁▁▁▁▁█▁▁▁▂▁▁▁▁▁
wandb: train/medusa0_top1 ▁▁▁▁▁▁▁▁▃▂▃▃▄▄▄▃▄▃▄▄▅▅▆▅▆▆▇▅▇▇▄▇█▇▅▇█▆▇▇
wandb: train/medusa1_loss ▇▇█▇▇▆▅▅▃▄▃▃▃▃▃▃▃▃▃▃▂▁▂▂▂▁▁▂▁▁▇▁▁▁▂▁▁▁▁▁
wandb: train/medusa1_top1 ▁▁▁▁▁▁▁▁▃▂▃▃▃▄▄▃▃▂▃▃▅▅▆▄█▆▇▅▇▇▅█▇▇▅▇█▆▆▇
wandb: train/medusa2_loss ▃▃▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁█▁▁▁▂▁▁▁▁▁
wandb: train/medusa2_top1 ▁▁▁▂▁▁▁▁▂▂▃▃▃▄▄▃▃▂▃▃▅▆▅▄█▆▆▅▆▆▄█▇▇▄▇█▆▆▇
wandb: train/total_flos ▁
wandb: train/train_loss ▁
wandb: train/train_runtime ▁
wandb: train/train_samples_per_second ▁
wandb: train/train_steps_per_second ▁
wandb:
wandb: Run summary:
wandb: train/epoch 2.0
wandb: train/global_step 16
wandb: train/learning_rate 0.0
wandb: train/loss 14.8906
wandb: train/medusa0_loss 4.25
wandb: train/medusa0_top1 0.28809
wandb: train/medusa1_loss 4.8125
wandb: train/medusa1_top1 0.22727
wandb: train/medusa2_loss 5.5
wandb: train/medusa2_top1 0.17293
wandb: train/total_flos 0.0
wandb: train/train_loss 23.98242
wandb: train/train_runtime 396.9266
wandb: train/train_samples_per_second 2.519
wandb: train/train_steps_per_second 0.04
最后但同样重要的是,不要忘记将此模型推送到 Hugging Face Hub,以便你可以在你的项目中使用它。
python -m medusa.hf_utils \ --folder zephyr_out_medusa_mlp_zephyr-7b-beta_medusa_3_lr_0.001_layers_1 \ --repo drbh/zephyr_medusa_demo
哇,我们成功训练了一个 Medusa 模型并将其推送到 Hugging Face Hub!🎉
< > GitHub 更新