在 Axolotl 中启用序列并行实现长上下文训练
本文最初发布于 Axolotl 的 Substack。你可以在此处阅读原文。
随着大型语言模型 (LLM) 在规模和上下文长度方面持续扩展,利用长上下文训练它们已成为一项重要的能力。然而,处理这些扩展序列会带来显著的内存挑战,即使在最大的 GPU 上也经常导致内存不足 (OOM) 错误。Axolotl 现在通过实现序列并行 (SP) 解决了这个问题,使研究人员和开发人员能够训练比以前可能更长上下文的模型。
什么是序列并行?
序列并行是一种将单个序列的处理分配到多个 GPU 上的技术。与其他分割模型参数(张量并行)或训练示例(数据并行)的并行方法不同,SP 将输入序列本身分割成块,每个 GPU 处理序列的一部分。
这种方法对于长上下文训练特别有效,因为它直接解决了主要限制:注意力机制中与序列长度相关的二次内存增长。
Axolotl 中的序列并行实现
Axolotl 的序列并行实现使用 ring-flash-attn
库,具体采用了 LLaMA-3 技术报告中的 llama3_flash_attn_varlen_func
实现。这种方法在保持计算效率的同时,将注意力计算分布到多个 GPU 上。
该实现的主要优点包括:
- 内存效率:通过在 GPU 之间分割序列,每个 GPU 所需的内存显著减少。
- 扩展能力:支持训练在可用硬件上原本不可能实现的序列长度。
- 可组合性:与样本打包和变长序列以及 Axolotl 支持的许多其他高级功能(Liger 内核、
torch.compile
、FSDP、DeepSpeed 等)兼容。
序列并行如何工作
启用序列并行后,注意力计算将根据指定的 sequence_parallel_degree
分布到 GPU 上。例如,如果将此值设置为 4,则每个序列将被分成 4 个等长块,每个块在不同的 GPU 上处理。
该实现处理:
- 分割序列
- 分配计算
- 管理 GPU 之间的通信
- 确保反向传播期间正确的梯度流
其中一些功能是集成 ring-flash-attn
包的结果!
在 Axolotl 中配置序列并行
在 Axolotl 中设置序列并行非常简单。安装 ring-flash-attn
后(建议通过 pip install axolotl[ring-flash-attn]
),您只需将 sequence_parallel_degree
参数添加到配置文件中。
sequence_parallel_degree: 4 # Set to the number of GPUs to split sequences across
flash_attention: true # SP requires flash attention
micro_batch_size: 1 # SP requires this is set to 1
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
heads_k_stride: 1
sequence_parallel_degree
应设置为可用 GPU 数量的约数。例如,如果您有 8 个 GPU,则可以根据内存需求将其设置为 2、4 或 8。
重要注意事项:
- SP 程度必须能整除可用 GPU 数量、序列长度和模型中的注意力头数量。
- 必须启用 Flash Attention 才能使序列并行工作。
- 微批量大小必须设置为 1。
- 支持样本打包。
- 由于
ring-flash-attn
实现中的 GPU 间通信,使用 NVLink 连接的 GPU 将提高性能。 - RL 训练器支持即将推出!
内存节省与扩展
序列并行带来的内存节省是巨大的。理论上,当序列并行程度为 n
时,可以期望将注意力操作所需的内存减少约 n
倍。实际上,由于 GPU 之间通信的开销,内存节省会略低,并与其他训练配置参数相互作用。
例如,在 SP 程度为 4 时,您应该能够处理比单个 GPU 所能容纳的序列长 3-4 倍的序列。
基准测试结果
该基准测试评估了 LLaMA 3.1 8B 在不同训练方法和各种 GPU 型号下序列并行扩展性。我们生成了人工数据,然后确定了给定模型在给定数量 GPU(等于 SP 程度)上可以运行的最大序列长度(不会在训练期间 OOM)。为了进一步突破序列长度限制和训练速度,我们启用了 Liger kernels 并使用了 AdamW 8bit 优化器。
请注意,在这些基准测试中,我们没有使用梯度检查点等技术,这些技术会以牺牲速度为代价节省更多的 VRAM。例如,在 gradient_checkpointing: true
的情况下,我们成功地在 8 个 H100 上运行了 LLaMA 3.1 8B 的微调,上下文长度超过 25 万,尽管单个训练步骤耗时超过 1 分钟(!)。您可以在 Axolotl 配置中设置此值(或 gradient_checkpointing: offload
以获得更多 VRAM 节省,同时牺牲更多速度)进行测试,并进行实验以找到适合您用例的最佳值设置。
这些基准测试运行的配置可在此处和此处找到。基准测试代码可在此处和此处找到。
LLaMA 3.1 8B 完全微调
完全微调显示了稳健的上下文长度扩展,使用 8 个 GPU 时达到单 GPU 上下文长度的 5 倍。
正如预期,由于 GPU 间通信开销,“速度提升/GPU 数量”指标随着并行度的增加而下降,但主要优势仍然是能够处理越来越长的序列。请务必明智地选择 SP 程度和序列长度,以平衡 VRAM 节省和训练速度!
H100 基准测试
SP 程度 | 最大上下文长度 | 上下文扩展 | 每秒令牌数 | 平均时间(秒) | 加速比 | 加速/GPU 数量 |
---|---|---|---|---|---|---|
1 | 9,216 | 1.00倍 | 7,317.50 | 1.2594 | 1.00倍 | 100.0% |
2 | 12,288 | 1.33x | 11,010.05 | 1.1161 | 1.50x | 75.2% |
4 | 24,576 | 2.67x | 14,278.66 | 1.7212 | 1.95 倍 | 48.8% |
8 | 46,080 | 5.00x | 16,035.84 | 2.8736 | 2.19x | 27.4% |
LLaMA 3.1 8B QLoRA 微调
QLoRA 展示了卓越的上下文长度扩展能力,在我们的 H100 基准测试中实现了接近线性的扩展,在 4090 基准测试中实现了精确的线性扩展。4 位量化减少了内存占用,从而提高了上下文处理能力。
随着 GPU 间通信的增加,“加速/GPU 数量”指标呈下降趋势,这符合预期模式。再次强调,请尝试使用 SP 程度和序列长度,以平衡 VRAM 节省和训练速度。
H100 基准测试
SP 程度 | 最大上下文长度 | 上下文扩展 | 每秒令牌数 | 平均时间(秒) | 加速比 | 加速/GPU 数量 |
---|---|---|---|---|---|---|
1 | 17,408 | 1.00倍 | 9,104.38 | 1.912 | 1.00倍 | 100.0% |
2 | 34,816 | 2.00x | 15,806.46 | 2.2026 | 1.74x | 86.8% |
4 | 66,560 | 3.82x | 12,313.60 | 5.4054 | 1.35x | 33.8% |
8 | 129,024 | 7.41x | 11,096.06 | 11.6279 | 1.22x | 15.2% |
NVIDIA H100 上全微调和 4 位 QLoRA 训练下,上下文长度随 GPU 数量/SP 程度的扩展情况。
4090 基准测试
SP 程度 | 最大上下文长度 | 上下文扩展 | 每秒令牌数 | 平均时间(秒) | 加速比 | 加速/GPU 数量 |
---|---|---|---|---|---|---|
1 | 4,096 | 1.00倍 | 3,674.11 | 1.1148 | 1.00倍 | 100.0% |
2 | 8,192 | 2.00x | 5,021.70 | 1.6313 | 1.37倍 | 68.3% |
4 | 16,384 | 4.00x | 6,455.30 | 2.5381 | 1.76x | 43.9% |
8 | 32,768 | 8.00x | 3,244.03 | 10.101 | 0.88x | -11.0% |
QLoRA 训练中,NVIDIA 4090 上下文长度随 GPU 数量/SP 程度的精确线性扩展。
特性可组合性
我们的 SP 实现与 Axolotl 支持的其他几个重要的省时省内存优化兼容,包括(但不限于):
- Liger 内核
- FSDP
- DeepSpeed ZeRO 1-3
torch.compile
- 样本打包
请尝试使用您喜欢的 Axolotl 功能(如果您遇到任何问题,请在GitHub 上提交 issue)!
开始使用
要在 Axolotl 配置中开始使用序列并行,请确保您拥有多个 GPU,启用 Flash Attention,并适当设置序列并行度。