使用 PyTorch FSDP 微调 Llama 2 70B

发布于 2023 年 9 月 13 日
在 GitHub 上更新

引言

在这篇博文中,我们将探讨如何使用 PyTorch FSDP 微调 Llama 2 70B 以及相关的最佳实践。我们将利用 Hugging Face Transformers、Accelerate 和 TRL。我们还将学习如何将 Accelerate 与 SLURM 结合使用。

全分片数据并行 (FSDP) 是一种范例,其中优化器状态、梯度和参数在设备之间分片。在前向传播过程中,每个 FSDP 单元执行一次 *all-gather 操作*以获取完整的权重,然后执行计算,然后丢弃来自其他设备的分片。在前向传播之后,计算损失,然后进行反向传播。在反向传播中,每个 FSDP 单元执行一次 all-gather 操作以获取完整的权重,并执行计算以获取局部梯度。这些局部梯度通过 *reduce-scatter 操作*在设备之间平均并分片,以便每个设备可以更新其分片的参数。有关 PyTorch FSDP 的更多信息,请参阅这篇博文:使用 PyTorch 全分片数据并行加速大型模型训练

FSDP Workflow

(来源:链接

使用的硬件

节点数量:2。最低要求为 1。
每个节点的 GPU 数量:8
GPU 类型:A100
GPU 显存:80GB
节点内连接:NVLink
每个节点的内存:1TB
每个节点的 CPU 核心数:96
节点间连接:弹性结构适配器

微调 LLaMa 70B 的挑战

在尝试使用 FSDP 微调 LLaMa 70B 时,我们遇到了三个主要挑战

  1. FSDP 在加载预训练模型后封装模型。如果节点内的每个进程/rank 都加载 Llama-70B 模型,它将需要 70 * 4 * 8 GB ~ 2TB 的 CPU RAM,其中 4 是每个参数的字节数,8 是每个节点上的 GPU 数量。这将导致 CPU RAM 内存不足,从而导致进程终止。

  2. 使用 FULL_STATE_DICT 并将 CPU 卸载到 rank 0 上保存整个中间检查点需要大量时间,并且由于广播期间无限期挂起,通常会导致 NCCL 超时错误。然而,在训练结束时,我们希望获得整个模型状态字典,而不是仅与 FSDP 兼容的分片状态字典。

  3. 我们需要提高速度并减少 VRAM 使用量,以加快训练速度并节省计算成本。

让我们看看如何解决上述挑战并微调 70B 模型!

在我们开始之前,以下是重现我们结果所需的所有资源

  1. 代码库:https://github.com/pacman100/DHS-LLM-Workshop/tree/main/chat_assistant/sft/training 带 flash-attn V2

  2. FSDP 配置:https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/configs/fsdp_config.yaml

  3. SLURM 脚本 launch.slurmhttps://gist.github.com/pacman100/1cb1f17b2f1b3139a63b764263e70b25

  4. 模型:meta-llama/Llama-2-70b-chat-hf

  5. 数据集:smangrul/code-chat-assistant-v1(LIMA+GUANACO 的混合,格式正确,可直接用于训练)

先决条件

首先,按照这些步骤安装 Flash Attention V2:Dao-AILab/flash-attention: 快速且内存高效的精确注意力 (github.com)。安装 PyTorch 的最新 nightly 版本,CUDA ≥11.8。按照 DHS-LLM-Workshop/code_assistant/training/requirements.txt 安装其余要求。在这里,我们将从主分支安装 🤗 Accelerate 和 🤗 Transformers。

微调

解决挑战 1

PR huggingface/transformers#25107huggingface/accelerate#1777 解决了第一个挑战,并且不需要用户进行代码更改。它执行以下操作:

  1. 在所有 rank 上创建无权重的模型(使用 meta 设备)。
  2. 仅在 rank==0 上加载状态字典,并使用该状态字典在 rank 0 上设置模型权重。
  3. 对于所有其他 rank,对 meta 设备上的每个参数执行 torch.empty(*param.size(), dtype=dtype)
  4. 因此,rank==0 将加载具有正确状态字典的模型,而所有其他 rank 将具有随机权重。
  5. 设置 sync_module_states=True,以便 FSDP 对象在训练开始前负责将其广播到所有 rank。

以下是 7B 模型在 2 个 GPU 上测量的内存消耗和模型参数在不同阶段的输出片段。我们可以观察到,在加载预训练模型期间,rank 0 和 rank 1 的 CPU 总峰值内存分别为 32744 MB1506 MB。因此,只有 rank 0 正在加载预训练模型,从而实现了 CPU RAM 的高效利用。完整的日志可以在这里找到。

accelerator.process_index=0 GPU Memory before entering the loading : 0
accelerator.process_index=0 GPU Memory consumed at the end of the loading (end-begin): 0
accelerator.process_index=0 GPU Peak Memory consumed during the loading (max-begin): 0
accelerator.process_index=0 GPU Total Peak Memory consumed during the loading (max): 0
accelerator.process_index=0 CPU Memory before entering the loading : 926
accelerator.process_index=0 CPU Memory consumed at the end of the loading (end-begin): 26415
accelerator.process_index=0 CPU Peak Memory consumed during the loading (max-begin): 31818
accelerator.process_index=0 CPU Total Peak Memory consumed during the loading (max): 32744

accelerator.process_index=1 GPU Memory before entering the loading : 0
accelerator.process_index=1 GPU Memory consumed at the end of the loading (end-begin): 0
accelerator.process_index=1 GPU Peak Memory consumed during the loading (max-begin): 0
accelerator.process_index=1 GPU Total Peak Memory consumed during the loading (max): 0
accelerator.process_index=1 CPU Memory before entering the loading : 933
accelerator.process_index=1 CPU Memory consumed at the end of the loading (end-begin): 10
accelerator.process_index=1 CPU Peak Memory consumed during the loading (max-begin): 573
accelerator.process_index=1 CPU Total Peak Memory consumed during the loading (max): 1506

解决挑战 2

通过在创建 FSDP 配置时选择 SHARDED_STATE_DICT 状态字典类型来解决。SHARDED_STATE_DICT 会单独保存每个 GPU 的分片,从而可以快速保存或从中间检查点恢复训练。当使用 FULL_STATE_DICT 时,第一个进程(rank 0)会在 CPU 上收集整个模型,然后以标准格式保存。

让我们使用以下命令创建 Accelerate 配置:

accelerate config --config_file "fsdp_config.yaml"

fsdp_config

生成的配置在这里:fsdp_config.yaml。在这里,分片策略是 FULL_SHARD。我们使用 TRANSFORMER_BASED_WRAP 进行自动包装策略,它使用 _no_split_module 来查找嵌套 FSDP 自动包装的 Transformer 块名称。我们使用 SHARDED_STATE_DICT 来保存中间检查点和优化器状态,这是 PyTorch 团队推荐的格式。请确保在训练开始时启用从 rank 0 广播模块参数,如上一段中解决挑战 1 所述。我们正在启用 bf16 混合精度训练。

对于最终检查点,整个模型状态字典,使用以下代码片段。

if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

trainer.save_model(script_args.output_dir) # alternatively, trainer.push_to_hub() if the whole ckpt is below 50GB as the LFS limit per file is 50GB 

解决挑战 3

Flash Attention 和启用梯度检查点对于加快训练速度和减少 VRAM 使用至关重要,从而实现微调并节省计算成本。代码库目前使用猴子补丁,实现位于 chat_assistant/training/llama_flash_attn_monkey_patch.py

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 引入了一种计算精确注意力的方法,该方法通过利用底层硬件/GPU 的内存层次结构知识,实现更快和更内存高效的计算——内存的带宽/速度越高,其容量就越小,因为它变得更昂贵。

如果我们遵循博客 让深度学习 Go Brrrr 的第一性原理,我们可以发现当前硬件上的 Attention 模块是 内存密集型/带宽密集型。原因是 Attention 主要由逐元素操作组成,如下图左侧所示。我们可以观察到,掩码、softmax 和 dropout 操作占据了大部分时间,而不是占据大部分 FLOP 的矩阵乘法。

Attention Bottlenecks

(来源:链接

这正是 Flash Attention 所解决的问题。其理念是**消除冗余的 HBM 读/写。**它通过将所有内容保留在 SRAM 中,执行所有中间步骤,然后才将最终结果写回 HBM 来实现这一点,这也称为**内核融合**。下面是其如何克服内存瓶颈的图示。kernel_fusion

(来源:链接

在正向和反向传播过程中,使用**切片**将 NxN softmax/分数计算分块,以克服 SRAM 内存大小的限制。为了启用切片,使用了在线 softmax 算法。在反向传播过程中,使用**重新计算**,以避免在正向传播过程中存储整个 NxN softmax/分数矩阵。这大大减少了内存消耗。

要简化和深入了解 Flash Attention,请参阅博客文章 ELI5: FlashAttention让深度学习 Go Brrrr 从第一性原理,以及原始论文 FlashAttention: 快速且内存高效的精确注意力,具有 IO 感知

整合所有功能

要使用 Accelerate 启动器和 SLURM 运行训练,请参阅此 gist launch.slurm。下面是演示如何使用 Accelerate 启动器运行训练的等效命令。请注意,我们正在覆盖 fsdp_config.yamlmain_process_ipmain_process_portmachine_ranknum_processesnum_machines 值。这里,另一个重要的一点是存储在所有节点之间共享。

accelerate launch \
    --config_file configs/fsdp_config.yaml \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank \$MACHINE_RANK \
    --num_processes 16 \
    --num_machines 2 \
    train.py \
    --seed 100 \
    --model_name "meta-llama/Llama-2-70b-chat-hf" \
    --dataset_name "smangrul/code-chat-assistant-v1" \
    --chat_template_format "none" \
    --add_special_tokens False \
    --append_concat_token False \
    --splits "train,test" \
    --max_seq_len 2048 \
    --max_steps 500 \
    --logging_steps 25 \
    --log_level "info" \
    --eval_steps 100 \
    --save_steps 250 \
    --logging_strategy "steps" \
    --evaluation_strategy "steps" \
    --save_strategy "steps" \
    --push_to_hub \
    --hub_private_repo True \
    --hub_strategy "every_save" \
    --bf16 True \
    --packing True \
    --learning_rate 5e-5 \
    --lr_scheduler_type "cosine" \
    --weight_decay 0.01 \
    --warmup_ratio 0.03 \
    --max_grad_norm 1.0 \
    --output_dir "/shared_storage/sourab/experiments/full-finetune-llama-chat-asst" \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --gradient_checkpointing True \
    --use_reentrant False \
    --dataset_text_field "content" \
    --use_flash_attn True \
    --ddp_timeout 5400 \
    --optim paged_adamw_32bit 

微调在约 13.5 小时内完成,以下是训练损失图。

train_loss

重要提示:在这里,如果您只使用一台带有 8 个 A100 80GB GPU 的节点,则需要 bitsandbytes 中的 paged_adamw_32bit 优化器。有关分页优化器的更多信息,请参阅 https://huggingface.co/docs/bitsandbytes/main/en/optimizers#paged-optimizers

下面是使用上述模型的一个对话示例

System Prompt: You are a helpful, respectful and honest assistant. Always answer as helpfully \
as possible, while being safe. Your answers should not include any harmful, \
unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that \
your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why \
instead of answering something not correct. If you don’t know the answer to a \
question, please don’t share false information.

- Human: What is Deep Learning? Explain like a Pirate.

+ Assistant: Arrr Grumete! Are ye lookin' fer a proper explanation o' Deep Learnin'? 
+ Alright then, listen up me hearty! Deep Learnin' be a type o' machine learnin' that uses artificial neural networks to analyze data. 
+ Aye, ye heard that right, mateys! Artificial neural networks be models o' the human brain, but instead o' using neurons, they use interconnected nodes or "artificial neurons" to process information. 
+ These nodes work together to form layers o' processing units that can learn from large amounts o' data. 
+ The more data these systems receive, the better they get at recognizing patterns and makin' decisions on their own. 
+ So hoist the sails and set course for adventure with deep learnin', me hearties! It's a treasure trove o' knowledge just waitin' to be discovered. 
+ Fair winds and following seas!

- Human: Now explain it like a chef.

+ Assistant: Certainly! Here's an explanation of deep learning from a culinary perspective:

+ Think of a recipe as a sequence of steps used to transform raw ingredients into a delicious dish. 
+ Similarly, in deep learning, there are multiple layers of "ingredients" (or features) that are combined and transformed through various operations to produce a final output or prediction. 
+ Each layer builds upon the previous one, refining and enhancing the representation of the input data until it reaches the desired outcome. 
+ Just as different spices enhance the flavors of food, each layer adds depth and complexity to the model's understanding of the input, allowing it to make more accurate predictions. 
+ And much like how chefs must carefully balance the amount of salt, sugar, and other seasonings in a recipe, the training process in deep learning involves adjusting the weights and biases of the network to achieve the perfect blend of accuracy and efficiency. 
+ With enough practice and experimentation, anyone can become a master chef... I mean, a master of deep learning! Bon appétit!

整个对话格式如下:

<|system|> system message <|endoftext|> <|prompter|> Q1 <|endoftext|> <|assistant|> A1 <|endoftext|> ...

结论

我们成功地使用 PyTorch FSDP 在多节点多 GPU 环境中微调了 70B Llama 模型,同时解决了各种挑战。我们了解了 🤗 Transformers 和 🤗 Accelerates 如何支持高效初始化大型模型,以克服 CPU RAM 内存不足的问题。随后,我们介绍了保存/加载中间检查点以及如何以可直接使用的方式保存最终模型的推荐实践。为了加快训练速度并减少 GPU 内存使用,我们强调了 Flash Attention 和梯度检查点的重要性。总的来说,我们可以看到使用 🤗 Accelerate 的简单配置如何实现在多节点多 GPU 环境中微调如此庞大的模型。

社区

注册登录 以发表评论