Accelerate 文档
低精度训练方法
并获取增强的文档体验
开始使用
低精度训练方法
新型硬件的发布催生了新的训练范式,可以更好地利用它们。目前,这以 8 位精度训练的形式出现,使用的软件包包括 TransformersEngine (TE) 或 MS-AMP。
为了介绍今天讨论的主题,我们建议查看低精度使用指南,因为本文档将经常引用它。
快速图表
下面是 MS-AMP 文档中的快速图表,显示了训练期间每种解决方案的不同位精度
优化级别 | 计算 (GEMM) | 通信 | 权重 | 主权重 | 权重梯度 | 优化器状态 |
---|---|---|---|---|---|---|
FP16 AMP | FP16 | FP32 | FP32 | 不适用 | FP32 | FP32+FP32 |
Nvidia TE | FP8 | FP32 | FP32 | 不适用 | FP32 | FP32+FP32 |
MS-AMP O1 | FP8 | FP8 | FP16 | 不适用 | FP8 | FP32+FP32 |
MS-AMP O2 | FP8 | FP8 | FP16 | 不适用 | FP8 | FP8+FP16 |
MS-AMP O3 | FP8 | FP8 | FP8 | FP16 | FP8 | FP8+FP16 |
TransformersEngine
TransformersEngine
是尝试以 8 位浮点进行训练的首个解决方案。它的工作原理是为模型中的某些层使用直接替换层,这些模型利用其 FP8 引擎来减少位数(例如从 32 位减少到 8 位),而不会降低模型的最终准确性。
具体来说,Accelerate 将查找并使用 TransformersEngine
版本替换以下层
nn.LayerNorm
替换为te.LayerNorm
nn.Linear
替换为te.Linear
因此,我们最终得到的模型的大部分层都在 BF16 中,而有些层在 FP8 中,从而减少了一些内存。
有传闻称,我们注意到,在使用 TransformerEngine
时,性能提升实际上并没有开始显现,直到模型中的绝大多数层都由要替换的这两层组成。因此,只有参数数量在数十亿及以上的较大型号才显示出性能改进。
TransformerEngine
可以接收许多不同的参数,这些参数可以自定义其执行 FP8 计算的方式及其功能。下面提供了参数的完整列表
margin
:用于梯度缩放的边距。interval
:用于缩放因子重新计算频率的间隔。fp8_format
:用于 FP8 配方的格式。必须是HYBRID
或E4M3
之一。(通常训练使用HYBRID
,评估使用E4M3
)amax_history_len
:用于缩放因子计算的历史记录长度amax_compute_algo
:用于缩放因子计算的算法。必须是max
或most_recent
之一。override_linear_precision
:是否以更高的精度执行fprop
、dgrad
和wgrad
GEMM。
您可以自定义其中的每一个,作为 utils.FP8RecipeKwargs 的一部分,以帮助优化模型的性能。
如果我们注意到前面提到的图表,TE 只是简单地将计算层转换为 FP8,而其他所有内容都在 FP32 中。因此,这最终会占用最多的内存,但这样做的好处是保证了训练期间最终准确度的损失最小。
MS-AMP
MS-AMP 采用了与 TransformersEngine
不同的方法,它提供了三个不同的优化级别,可以将更多操作转换为 FP8 或 FP16。
基本优化级别 (
O1
) 以 FP8 传递权重通信(例如在 DDP 中),以 FP16 存储模型权重,并将优化器状态保留在 FP32 中。此优化级别的主要好处是我们可以将通信带宽减少一半。此外,由于 1/2 的所有内容都转换为 FP8,并且权重转换为 FP16,因此节省了更多 GPU 内存。值得注意的是,优化器状态仍然保留在 FP32 中。第二级优化 (
O2
) 在此基础上进行了改进,也降低了优化器状态的精度。一个在 FP8 中,另一个在 FP16 中。通常,已经表明,由于现在每个状态都在 FP16 或 FP8 中,因此这将仅提供不降低最终准确度、提高训练速度和减少内存的净收益。最后,MS-AMP 有第三级优化 (
O3
),它在 DDP 场景(如 DeepSpeed)中有所帮助。内存中模型的权重完全转换为 FP8,而主权重现在以 FP16 存储。这最大限度地减少了内存,因为现在不仅几乎所有内容都在 FP8 中,而且只有两个状态保留在 FP16 中。目前,仅支持 0.9.2 及更早版本的 DeepSpeed,因此 Accelerate 集成中不包含此功能
将两者结合
还需要进行更多实验,但据指出,结合 MS-AMP 和 TransformersEngine 可以通过依赖 NVIDIA 优化的 FP8 运算符并利用 MS-AMP 减少内存开销的方式来实现最高吞吐量。
< > 在 GitHub 上更新