使用 trl 和 DeepSpeed 进行分布式 SFT 第一部分:本地启动

社区文章 发布于 2025 年 1 月 23 日

引言

欢迎阅读本系列文章,其中记录了我首次尝试使用 trlDeepSpeed 运行分布式监督微调 (SFT) 任务时学到的经验教训。

在第一部分中,我将向您展示我如何按照官方 trl 文档运行我的第一个本地 SFT 实验。

在第二部分中,我们将利用并行训练在本地环境中完成完整的 SFT 任务。

最后,在第三部分中,我们将通过使用 Kubeflow 的训练操作器将相同的训练任务提交到 Kubernetes 集群,从而更进一步。

关于我自己的一个小注:我是一名软件开发工程师,对深度学习领域还比较陌生。如果这些文章对您来说太基础了,感谢您在我学习过程中给予的耐心。

先决条件

要遵循本教程,您需要一台至少配备一块 NVIDIA GPU 的机器。我在 V100 上运行了实验,没有遇到任何内存问题。如果您的 GPU 显存小于 32GB,您可能需要减少 per_device_train_batch_size 或考虑使用截断(尽管不推荐)以防止 CUDA 内存不足 (OOM) 错误。

您还需要以下依赖项

datasets
transformers
trl
torch

训练

trl 库提供了一些出色的示例训练脚本,我们将从这个开始:https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py

将脚本复制到您的开发机器(或笔记本)上,选择一个基础模型,然后选择一个 SFT 数据集来运行实验。对于本次实验,我选择了 Qwen/Qwen2.5-0.5B 作为基础模型,因为它体积小巧,并选择了 BAAI/Infinity-Instruct 作为 SFT 数据集(不知为何随机选择 😌)。您可以在这里探索其他有趣的数据集:https://github.com/mlabonne/llm-datasets

命令行参数

训练脚本 (sft.py) 提供了各种有用的命令行参数,可让您自定义微调过程。这些参数映射到以下类中的特定属性

您可以通过在它们前面加上 -- 直接从命令行传递任何这些参数。例如,传递 --dataset_name 将设置 trl.ScriptArguments 类中的 dataset_name 字段。

让我们看一下本教程中使用的参数

  • --model_name_or_path:指定要微调的基础模型。
  • --dataset_name:定义用于微调的数据集。
  • --dataset_config:某些数据集带有多个配置(版本)。此参数允许您选择要使用的版本。
  • --do_train:告诉脚本开始训练过程。
  • --per_device_train_batch_size:定义每个 GPU 的批处理大小。
  • --output_dir:指定模型将保存到的目录。
  • --max_steps:设置最大训练步数。
  • --logging_steps:设置训练期间记录日志的频率。

为方便起见,我更喜欢将完整的命令保存在 shell 脚本中,以便于执行。这是我为本教程使用的脚本

python sft.py \
  --model_name_or_path Qwen/Qwen2.5-0.5B \
  --dataset_name BAAI/Infinity-Instruct \
  --dataset_config 0625 \
  --do_train \
  --per_device_train_batch_size 4 \
  --output_dir /tmp/my-first-sft-exp \
  --max_steps 10 \
  --logging_steps 1

备注

  • 我选择了最小版本的数据集,并将实验限制在仅 10 步以内,以加快运行速度。
  • 由于训练只有 10 步,我将 --logging_steps 设置为 1,以便更频繁地查看日志。
  • --per_device_train_batch_size 设置为 4,因为这里的目标不是模型质量,而只是为了在没有 CUDA OOM 的情况下运行实验。任何适合您的显存的数字都应该可以。

2025-02-18 更新

trl 提供了一个方便的辅助函数来从 YAML 文件解析训练参数,您可以在此处找到更多详细信息。

通过此功能,您可以将上述训练参数保存在 YAML 文件(例如 recipe.yaml)中,如下所示

model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: BAAI/Infinity-Instruct
dataset_config: '0625'
do_train: true
per_device_train_batch_size: 4
output_dir: /tmp/my-first-sft-exp
max_steps: 10
logging_steps: 1

并启动训练,命令如下

python sft.py --config recipe.yaml

哦豁

现在,如果您使用相同的数据集并执行相同的脚本,您可能会遇到一个(不太有用)的错误消息

$ ./quickstart.sh 
Resolving data files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:00<00:00, 50.35it/s]
Map:   0%|                                                                                                                                                                         | 0/659808 [00:00<?, ? examples/s]
Traceback (most recent call last):
  File "/home/jovyan/sft-walkthrough/sft.py", line 126, in <module>
    main(script_args, training_args, model_args)
  File "/home/jovyan/sft-walkthrough/sft.py", line 97, in main
    trainer = SFTTrainer(
  ...
  File "/home/jovyan/.local/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 416, in tokenize
    element[dataset_text_field] if formatting_func is None else formatting_func(element),
  File "/home/jovyan/.local/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 277, in __getitem__
    value = self.data[key]
KeyError: 'text'

修复方法

2025-02-18 更新

  • 从 trl 0.15.0 开始(在 此 PR 中),不再支持“conversations”列。我们需要将其重命名为“messages”。
  • 此 PR 中(撰写本文时尚未发布),对“conversations”列的支持已恢复,整个预处理过程也得到了简化,我们不再需要手动映射字典键(“from”->“role”,“value”->“content”)。

2025-02-19 更新

上述 PR 已在 trl 0.15.1 中发布。

这个错误消息有点令人困惑——它声称 SFTTrainer 要求数据集具有“text”字段。然而,根据数据集格式和类型,“text”用于标准数据集,而“messages”应用于对话数据集。经过大量搜索,我发现了这个追踪问题这行代码这个函数。看来对于当前实现(trl == 0.13.0),我们有两种选择

  1. 自行格式化数据集(应用聊天模板)并将格式化后的数据放入“text”字段中。
  2. 以允许 trl 为我们处理转换的方式转换我们的数据集。

要使第二种选择生效,数据集必须

  • 包含“messages”或“conversations”字段。
  • “messages”(或“conversations”)字段中的每个元素都包含“content”字段和“role”字段。

检查我使用的数据集发现了一个不匹配之处:虽然它有一个“conversations”字段,但内部的元素使用“from”和“value”作为键,而不是“role”和“content”。作为一名懒惰的程序员,我选择了第二种方法,并相应地更新了训练脚本。此外,我还删除了数据集中所有其他字段,因为它们在 SFT 任务中未使用。删除它们将稍微减少内存占用并加快处理速度。

...
def main(script_args, training_args, model_args):
    ...
    ################
    # Dataset
    ################
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    def convert_fields(message: dict) -> dict:
        _message = {
          "role": message["from"],
          "content": message["value"],
        }
        # Qwen2.5 tokenizer only supports "system", "user", "assistant" and "tool"
        # See <https://huggingface.co/Qwen/Qwen2.5-3B/blob/main/tokenizer_config.json#L198>
        if _message["role"] == "human":
            _message["role"] = "user"
        elif _message["role"] == "gpt":
            _message["role"] = "assistant"
        elif _message["role"] == "system":
            # nothing to be done.
            ...
        else:
            # In case there are any other roles, print them so we can improve in next iteration.
            print(_message["role"])
        return _message

    def convert_messages(example):
        example["conversations"] = [convert_fields(message) for message in example["conversations"]]
        return example

    # remove unused fields
    dataset = dataset.remove_columns(["id", "label", "langdetect", "source"]).map(convert_messages)
    ...

更新后,脚本运行没有任何问题!您应该能够看到如下训练日志

$ ./quickstart.sh 
Resolving data files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:02<00:00, 17.26it/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 659808/659808 [01:19<00:00, 8280.44 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 659808/659808 [08:33<00:00, 1284.45 examples/s]
{'loss': 1.8859, 'grad_norm': 14.986075401306152, 'learning_rate': 1.8e-05, 'epoch': 0.0}                                                                                                                                     
{'loss': 1.4527, 'grad_norm': 13.9092378616333, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.0}                                                                                                                        
{'loss': 1.467, 'grad_norm': 7.388503074645996, 'learning_rate': 1.4e-05, 'epoch': 0.0}                                                                                                                                       
{'loss': 1.7757, 'grad_norm': 9.457520484924316, 'learning_rate': 1.2e-05, 'epoch': 0.0}                                                                                                                                      
{'loss': 1.9043, 'grad_norm': 10.256357192993164, 'learning_rate': 1e-05, 'epoch': 0.0}                                                                                                                                       
{'loss': 1.6163, 'grad_norm': 10.774249076843262, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.0}                                                                                                                       
{'loss': 1.1774, 'grad_norm': 5.897563457489014, 'learning_rate': 6e-06, 'epoch': 0.0}                                                                                                                                        
{'loss': 1.8093, 'grad_norm': 8.3130464553833, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.0}                                                                                                                          
{'loss': 1.8387, 'grad_norm': 7.102719306945801, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.0}                                                                                                                       
{'loss': 1.4251, 'grad_norm': 9.853829383850098, 'learning_rate': 0.0, 'epoch': 0.0}                                                                                                                                          
{'train_runtime': 38.8598, 'train_samples_per_second': 1.029, 'train_steps_per_second': 0.257, 'train_loss': 1.635251808166504, 'epoch': 0.0}                                                                                 
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:38<00:00,  3.89s/it]

结论

在第一部分中,我们逐步介绍了如何使用 trl 设置本地 SFT 实验。该库为使用自定义数据集微调 LLM 提供了用户友好的界面。我们还介绍了 trlSFTTrainer 所需的正确数据集格式以及如何预处理数据集以满足这些要求。

在下一部分中,我们将深入探讨如何使用单节点、多 GPU 配置在本地扩展此设置,以完成完整的 SFT 任务。此外,我们还将探索各种优化技术,以将更大的模型适配到您的 GPU 中并加速训练过程。敬请期待!

社区

注册登录以评论