使用 trl 和 DeepSpeed 进行分布式 SFT 第一部分:本地启动
引言
欢迎阅读本系列文章,其中记录了我首次尝试使用 trl 和 DeepSpeed 运行分布式监督微调 (SFT) 任务时学到的经验教训。
在第一部分中,我将向您展示我如何按照官方 trl 文档运行我的第一个本地 SFT 实验。
在第二部分中,我们将利用并行训练在本地环境中完成完整的 SFT 任务。
最后,在第三部分中,我们将通过使用 Kubeflow 的训练操作器将相同的训练任务提交到 Kubernetes 集群,从而更进一步。
关于我自己的一个小注:我是一名软件开发工程师,对深度学习领域还比较陌生。如果这些文章对您来说太基础了,感谢您在我学习过程中给予的耐心。
先决条件
要遵循本教程,您需要一台至少配备一块 NVIDIA GPU 的机器。我在 V100 上运行了实验,没有遇到任何内存问题。如果您的 GPU 显存小于 32GB,您可能需要减少 per_device_train_batch_size
或考虑使用截断(尽管不推荐)以防止 CUDA 内存不足 (OOM) 错误。
您还需要以下依赖项
datasets
transformers
trl
torch
训练
trl
库提供了一些出色的示例训练脚本,我们将从这个开始:https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py
将脚本复制到您的开发机器(或笔记本)上,选择一个基础模型,然后选择一个 SFT 数据集来运行实验。对于本次实验,我选择了 Qwen/Qwen2.5-0.5B 作为基础模型,因为它体积小巧,并选择了 BAAI/Infinity-Instruct 作为 SFT 数据集(不知为何随机选择 😌)。您可以在这里探索其他有趣的数据集:https://github.com/mlabonne/llm-datasets。
命令行参数
训练脚本 (sft.py
) 提供了各种有用的命令行参数,可让您自定义微调过程。这些参数映射到以下类中的特定属性
您可以通过在它们前面加上 --
直接从命令行传递任何这些参数。例如,传递 --dataset_name
将设置 trl.ScriptArguments
类中的 dataset_name
字段。
让我们看一下本教程中使用的参数
--model_name_or_path
:指定要微调的基础模型。--dataset_name
:定义用于微调的数据集。--dataset_config
:某些数据集带有多个配置(版本)。此参数允许您选择要使用的版本。--do_train
:告诉脚本开始训练过程。--per_device_train_batch_size
:定义每个 GPU 的批处理大小。--output_dir
:指定模型将保存到的目录。--max_steps
:设置最大训练步数。--logging_steps
:设置训练期间记录日志的频率。
为方便起见,我更喜欢将完整的命令保存在 shell 脚本中,以便于执行。这是我为本教程使用的脚本
python sft.py \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name BAAI/Infinity-Instruct \
--dataset_config 0625 \
--do_train \
--per_device_train_batch_size 4 \
--output_dir /tmp/my-first-sft-exp \
--max_steps 10 \
--logging_steps 1
备注
- 我选择了最小版本的数据集,并将实验限制在仅 10 步以内,以加快运行速度。
- 由于训练只有 10 步,我将
--logging_steps
设置为 1,以便更频繁地查看日志。 --per_device_train_batch_size
设置为 4,因为这里的目标不是模型质量,而只是为了在没有 CUDA OOM 的情况下运行实验。任何适合您的显存的数字都应该可以。
2025-02-18 更新
trl
提供了一个方便的辅助函数来从 YAML 文件解析训练参数,您可以在此处找到更多详细信息。通过此功能,您可以将上述训练参数保存在 YAML 文件(例如
recipe.yaml
)中,如下所示
model_name_or_path: Qwen/Qwen2.5-0.5B dataset_name: BAAI/Infinity-Instruct dataset_config: '0625' do_train: true per_device_train_batch_size: 4 output_dir: /tmp/my-first-sft-exp max_steps: 10 logging_steps: 1
并启动训练,命令如下
python sft.py --config recipe.yaml
哦豁
现在,如果您使用相同的数据集并执行相同的脚本,您可能会遇到一个(不太有用)的错误消息
$ ./quickstart.sh
Resolving data files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:00<00:00, 50.35it/s]
Map: 0%| | 0/659808 [00:00<?, ? examples/s]
Traceback (most recent call last):
File "/home/jovyan/sft-walkthrough/sft.py", line 126, in <module>
main(script_args, training_args, model_args)
File "/home/jovyan/sft-walkthrough/sft.py", line 97, in main
trainer = SFTTrainer(
...
File "/home/jovyan/.local/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 416, in tokenize
element[dataset_text_field] if formatting_func is None else formatting_func(element),
File "/home/jovyan/.local/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 277, in __getitem__
value = self.data[key]
KeyError: 'text'
修复方法
2025-02-18 更新
- 从 trl 0.15.0 开始(在 此 PR 中),不再支持“conversations”列。我们需要将其重命名为“messages”。
- 在 此 PR 中(撰写本文时尚未发布),对“conversations”列的支持已恢复,整个预处理过程也得到了简化,我们不再需要手动映射字典键(“from”->“role”,“value”->“content”)。
2025-02-19 更新
上述 PR 已在 trl 0.15.1 中发布。
这个错误消息有点令人困惑——它声称 SFTTrainer
要求数据集具有“text”字段。然而,根据数据集格式和类型,“text”用于标准数据集,而“messages”应用于对话数据集。经过大量搜索,我发现了这个追踪问题、这行代码和这个函数。看来对于当前实现(trl == 0.13.0
),我们有两种选择
- 自行格式化数据集(应用聊天模板)并将格式化后的数据放入“text”字段中。
- 以允许
trl
为我们处理转换的方式转换我们的数据集。
要使第二种选择生效,数据集必须
- 包含“messages”或“conversations”字段。
- “messages”(或“conversations”)字段中的每个元素都包含“content”字段和“role”字段。
检查我使用的数据集发现了一个不匹配之处:虽然它有一个“conversations”字段,但内部的元素使用“from”和“value”作为键,而不是“role”和“content”。作为一名懒惰的程序员,我选择了第二种方法,并相应地更新了训练脚本。此外,我还删除了数据集中所有其他字段,因为它们在 SFT 任务中未使用。删除它们将稍微减少内存占用并加快处理速度。
...
def main(script_args, training_args, model_args):
...
################
# Dataset
################
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
def convert_fields(message: dict) -> dict:
_message = {
"role": message["from"],
"content": message["value"],
}
# Qwen2.5 tokenizer only supports "system", "user", "assistant" and "tool"
# See <https://huggingface.co/Qwen/Qwen2.5-3B/blob/main/tokenizer_config.json#L198>
if _message["role"] == "human":
_message["role"] = "user"
elif _message["role"] == "gpt":
_message["role"] = "assistant"
elif _message["role"] == "system":
# nothing to be done.
...
else:
# In case there are any other roles, print them so we can improve in next iteration.
print(_message["role"])
return _message
def convert_messages(example):
example["conversations"] = [convert_fields(message) for message in example["conversations"]]
return example
# remove unused fields
dataset = dataset.remove_columns(["id", "label", "langdetect", "source"]).map(convert_messages)
...
更新后,脚本运行没有任何问题!您应该能够看到如下训练日志
$ ./quickstart.sh
Resolving data files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:02<00:00, 17.26it/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 659808/659808 [01:19<00:00, 8280.44 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 659808/659808 [08:33<00:00, 1284.45 examples/s]
{'loss': 1.8859, 'grad_norm': 14.986075401306152, 'learning_rate': 1.8e-05, 'epoch': 0.0}
{'loss': 1.4527, 'grad_norm': 13.9092378616333, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.0}
{'loss': 1.467, 'grad_norm': 7.388503074645996, 'learning_rate': 1.4e-05, 'epoch': 0.0}
{'loss': 1.7757, 'grad_norm': 9.457520484924316, 'learning_rate': 1.2e-05, 'epoch': 0.0}
{'loss': 1.9043, 'grad_norm': 10.256357192993164, 'learning_rate': 1e-05, 'epoch': 0.0}
{'loss': 1.6163, 'grad_norm': 10.774249076843262, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.0}
{'loss': 1.1774, 'grad_norm': 5.897563457489014, 'learning_rate': 6e-06, 'epoch': 0.0}
{'loss': 1.8093, 'grad_norm': 8.3130464553833, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.0}
{'loss': 1.8387, 'grad_norm': 7.102719306945801, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.0}
{'loss': 1.4251, 'grad_norm': 9.853829383850098, 'learning_rate': 0.0, 'epoch': 0.0}
{'train_runtime': 38.8598, 'train_samples_per_second': 1.029, 'train_steps_per_second': 0.257, 'train_loss': 1.635251808166504, 'epoch': 0.0}
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:38<00:00, 3.89s/it]
结论
在第一部分中,我们逐步介绍了如何使用 trl
设置本地 SFT 实验。该库为使用自定义数据集微调 LLM 提供了用户友好的界面。我们还介绍了 trl
的 SFTTrainer
所需的正确数据集格式以及如何预处理数据集以满足这些要求。
在下一部分中,我们将深入探讨如何使用单节点、多 GPU 配置在本地扩展此设置,以完成完整的 SFT 任务。此外,我们还将探索各种优化技术,以将更大的模型适配到您的 GPU 中并加速训练过程。敬请期待!