低精度训练方法
新型硬件的发布导致了新的训练范式的出现,这些范式能够更好地利用它们。目前,这以使用 TransformersEngine (TE) 或 MS-AMP 等软件包进行 8 位精度训练的形式出现。
为了更好地了解今天讨论的主题,我们建议您阅读 低精度使用指南,因为本文档将定期引用该指南。
快速图表
以下是 MS-AMP 文档中的快速图表,显示了训练期间每个解决方案的不同位精度。
优化级别 | 计算 (GEMM) | 通信 | 权重 | 主权重 | 权重梯度 | 优化器状态 |
---|---|---|---|---|---|---|
FP16 AMP | FP16 | FP32 | FP32 | N/A | FP32 | FP32+FP32 |
Nvidia TE | FP8 | FP32 | FP32 | N/A | FP32 | FP32+FP32 |
MS-AMP O1 | FP8 | FP8 | FP16 | N/A | FP8 | FP32+FP32 |
MS-AMP O2 | FP8 | FP8 | FP16 | N/A | FP8 | FP8+FP16 |
MS-AMP O3 | FP8 | FP8 | FP8 | FP16 | FP8 | FP8+FP16 |
TransformersEngine
TransformersEngine
是第一个尝试在 8 位浮点数上进行训练的解决方案。它通过使用模型中某些层的直接替换层来工作,这些层利用其 FP8 引擎来减少位数(例如,从 32 位减少到 8 位),而不会降低模型的最终精度。
具体来说,Accelerate 会找到以下层并用TransformersEngine
版本替换它们
nn.LayerNorm
用于te.LayerNorm
nn.Linear
用于te.Linear
因此,我们最终得到一个模型,其中大多数层处于 BF16,而一些层处于 FP8,从而减少了一些内存。
据观察,在模型中的大多数层都被替换为这两个层之前,使用TransformerEngine
并没有真正开始显示性能提升。因此,只有当参数数量达到数十亿或更多时,更大的模型才显示出性能改进。
TransformerEngine
可以接收许多不同的参数来定制其执行 FP8 计算的方式以及它们的执行方式。下面列出了所有参数。
margin
:用于梯度缩放的边距。interval
:用于重新计算缩放因子频率的间隔。fp8_format
:用于 FP8 配方的格式。必须为HYBRID
或E4M3
之一。(通常HYBRID
用于训练,E4M3
用于评估)amax_history_len
:用于缩放因子计算的历史记录长度。amax_compute_algo
:用于缩放因子计算的算法。必须为max
或most_recent
之一。override_linear_precision
:是否以更高的精度执行fprop
、dgrad
和wgrad
GEMM。
您可以自定义其中的每一个作为 utils.FP8RecipeKwargs 的一部分,以帮助优化模型的性能。
如果我们在前面提到的图表中注意到,TE 只是将计算层转换为 FP8,而其他所有内容都保留在 FP32 中。因此,这最终会使用最多的内存,但好处是保证在训练期间最终精度损失最小。
MS-AMP
MS-AMP 采用与TransformersEngine
不同的方法,提供三个不同的优化级别来将更多操作转换为 FP8 或 FP16。
基本优化级别 (
O1
) 将权重的通信(例如在 DDP 中)传递到 FP8,将模型的权重存储在 FP16 中,并将优化器状态保留在 FP32 中。这种优化级别的主要优势在于,我们可以将通信带宽降低约一半。此外,由于将 1/2 的所有内容转换为 FP8,并且将权重转换为 FP16,因此可以节省更多的 GPU 内存。值得注意的是,优化器状态都保留在 FP32 中。第二优化级别 (
O2
) 在此基础上进行改进,还降低了优化器状态的精度。一个状态在 FP8 中,另一个状态在 FP16 中。通常,这只会带来最终精度没有下降、训练速度提高以及内存减少的净收益,因为现在所有状态都位于 FP16 或 FP8 中。最后,MS-AMP 具有第三个优化级别 (
O3
),它在 DDP 场景(例如 DeepSpeed)中提供帮助。内存中的模型权重完全转换为 FP8,并且主权重现在存储在 FP16 中。这以最大程度地减少内存为目标,因为现在几乎所有内容都位于 FP8 中,只有两个状态保留在 FP16 中。目前,仅支持 DeepSpeed 0.9.2 之前的版本,因此此功能未包含在 Accelerate 集成中。
组合两者
需要进行更多实验,但据观察,将 MS-AMP 和 TransformersEngine 组合起来可以实现最高吞吐量,因为它们依赖于 NVIDIA 优化的 FP8 运算符并利用 MS-AMP 如何减少内存开销。
< > 在 GitHub 上更新