Accelerate 文档

低精度训练方法

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

低精度训练方法

新型硬件的发布催生了能更好地利用这些硬件的新训练范式。目前,这体现为使用诸如 TransformersEngine (TE) 或 MS-AMP 等软件包进行 8 位精度训练。

为了介绍今天讨论的主题,我们建议您回顾低精度使用指南,因为本文档将经常引用它。

快速图表

以下是 MS-AMP 文档中的一个快速图表,显示了训练期间每种解决方案的不同位精度。

优化级别 计算(GEMM) 通信 重量 主权重 权重梯度 优化器状态
FP16 AMP FP16 FP32 FP32 不适用 FP32 FP32+FP32
Nvidia TE FP8 FP32 FP32 不适用 FP32 FP32+FP32
MS-AMP O1 FP8 FP8 FP16 不适用 FP8 FP32+FP32
MS-AMP O2 FP8 FP8 FP16 不适用 FP8 FP8+FP16
MS-AMP O3 FP8 FP8 FP8 FP16 FP8 FP8+FP16

TransformersEngine

TransformersEngine 是第一个尝试用 8 位浮点数进行训练的解决方案。它的工作原理是为模型中的某些层使用可直接替换的层,利用其 FP8 引擎来减少位数(例如从 32 位减少到 8 位),而不会降低模型的最终精度。

具体来说,Accelerate 会查找并用 TransformersEngine 版本替换以下层

  • nn.LayerNorm 替换为 te.LayerNorm
  • nn.Linear 替换为 te.Linear

结果,我们得到的模型中,大部分层是 BF16,而一些层是 FP8,从而减少了部分内存占用。

根据经验,我们注意到,在使用 TransformerEngine 时,只有当模型中绝大多数层都是由这两个可替换的层构成时,性能提升才真正开始显现。因此,只有当参数数量达到数十亿级别或更高的大型模型才显示出性能改进。

TransformerEngine 可以接收许多不同的参数,以自定义其执行 FP8 计算的方式和功能。下面是完整的参数列表:

  • margin:用于梯度缩放的边距。
  • interval:用于重新计算缩放因子的频率间隔。
  • fp8_format``:用于 FP8 recipe 的格式。必须是 HYBRIDE4M3 之一。(通常 HYBRID 用于训练,E4M3` 用于评估)
  • amax_history_len:用于缩放因子计算的历史记录长度。
  • amax_compute_algo:用于缩放因子计算的算法。必须是 maxmost_recent 之一。
  • override_linear_precision:是否以更高精度执行 fpropdgradwgrad GEMM 运算。

您可以将这些参数作为 utils.FP8RecipeKwargs 的一部分进行自定义,以帮助优化模型的性能。

如果我们注意前面提到的图表,TE 只是将计算层转换为 FP8,而其他所有部分都保持 FP32。因此,这最终会占用最多的内存,但其好处是保证了训练期间最终精度的损失最小。

MS-AMP

MS-AMP 采用了与 TransformersEngine 不同的方法,它提供了三种不同的优化级别,以便将更多的操作转换为 FP8 或 FP16。

  • 基础优化级别(O1)以 FP8 格式传递权重通信(例如在 DDP 中),将模型权重存储为 FP16,并将优化器状态保留为 FP32。这个优化级别的主要好处是我们可以将通信带宽减少大约一半。此外,由于一半的数据被转换为 FP8,权重被转换为 FP16,从而节省了更多的 GPU 内存。值得注意的是,两个优化器状态都保持为 FP32。

  • 第二个优化级别(O2)在此基础上进一步改进,通过降低优化器状态的精度。一个状态是 FP8,另一个是 FP16。通常情况下,这已被证明只会在不降低最终精度的情况下带来净收益,同时提高了训练速度,减少了内存占用,因为现在每个状态要么是 FP16,要么是 FP8。

  • 最后,MS-AMP 还有一个第三优化级别(O3),它在 DDP 场景(如 DeepSpeed)中有所帮助。内存中的模型权重完全转换为 FP8,主权重现在存储为 FP16。这最大限度地减少了内存占用,因为现在不仅几乎所有东西都是 FP8,只剩下两个状态是 FP16。目前,只支持到 0.9.2 版本的 DeepSpeed,因此该功能未包含在 Accelerate 集成中。

两者结合

还需要进行更多实验,但有人指出,结合 MS-AMP 和 TransformersEngine 可能会通过依赖 NVIDIA 优化的 FP8 算子并利用 MS-AMP 减少内存开销的方式,实现最高的吞吐量。

< > 在 GitHub 上更新