梯度同步
PyTorch 的分布式模块通过在系统中的所有 GPU 之间来回通信来运行。这种通信需要时间,并且确保所有进程都知道彼此的状态是在使用 ddp
模块时在特定的触发点发生的。
这些触发点被添加到 PyTorch 模型中,特别是它们的 forward()
和 backward()
方法。当模型使用 DistributedDataParallel
包装时,就会发生这种情况。
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
model = nn.Linear(10, 10)
ddp_model = DistributedDataParallel(model)
在 Accelerate 中,当调用 prepare() 并传入您的模型时,此转换会自动发生。
+ from accelerate import Accelerator
+ accelerator = Accelerator()
import torch.nn as nn
- from torch.nn.parallel import DistributedDataParallel
model = nn.Linear(10,10)
+ model = accelerator.prepare(model)
梯度累积的减速
您现在了解到,PyTorch 在分布式设置中训练时,会在您的 PyTorch 模型的 forward
和 backward
方法中添加钩子。但这如何会导致代码速度变慢呢?
在 DDP(分布式数据并行)中,进程执行和运行的特定顺序在特定点上是预期的,并且这些顺序也必须在大致相同的时间发生,然后才能继续执行。
最直接的例子是当您通过 optimizer.step()
更新模型参数时。在没有梯度累积的情况下,模型的所有实例都需要计算、整理和更新其梯度,然后才能继续处理下一批数据。当执行梯度累积时,您会累积 n
个损失梯度,并在达到 n
批次之前跳过 optimizer.step()
。由于所有训练进程只需要在调用 optimizer.step()
时同步,而不需要修改训练步骤,因此这种不必要的进程间通信会导致明显的减速。
如何避免这种开销呢?
解决减速问题
由于您在处理这些批次时跳过了模型参数更新,因此在实际调用 optimizer.step()
之前,不需要同步它们的梯度。PyTorch 无法自动判断何时需要执行此操作,但它们确实提供了一个工具来帮助您,该工具通过 no_sync
上下文管理器来帮助您,该管理器在将模型转换为 DDP 后添加到您的模型中。
在此上下文管理器下,PyTorch 会在调用 .backward()
时跳过梯度同步,而对 .backward()
的第一次调用将在此上下文管理器之外触发同步。请参阅下面的示例。
ddp_model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer)
for index, batch in enumerate(dataloader):
inputs, targets = batch
# Trigger gradient synchronization on the last batch
if index != (len(dataloader) - 1):
with ddp_model.no_sync():
# Gradients only accumulate
outputs = ddp_model(inputs)
loss = loss_func(outputs)
accelerator.backward(loss)
else:
# Gradients finally sync
outputs = ddp_model(inputs)
loss = loss_func(outputs)
accelerator.backward(loss)
optimizer.step()
在 Accelerate 中,为了将其转换为一个无论训练设备是什么都可以调用的 API(虽然如果您不在分布式系统中,它可能不会执行任何操作!),ddp_model.no_sync
被替换为 no_sync() 并以相同的方式运作。
ddp_model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer)
for index, batch in enumerate(dataloader):
inputs, targets = batch
# Trigger gradient synchronization on the last batch
if index != (len(dataloader)-1):
- with ddp_model.no_sync():
+ with accelerator.no_sync(model):
# Gradients only accumulate
outputs = ddp_model(inputs)
loss = loss_func(outputs, targets)
accelerator.backward(loss)
else:
# Gradients finally sync
outputs = ddp_model(inputs)
loss = loss_func(outputs)
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
正如您所料,accumulate() 函数通过跟踪当前批次号来封装此条件检查,为您提供最终的梯度累积 API。
ddp_model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer)
for batch in dataloader:
with accelerator.accumulate(model):
optimizer.zero_grad()
inputs, targets = batch
outputs = model(inputs)
loss = loss_function(outputs, targets)
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
因此,在 API 选择方面,您应该使用 *accelerator.accumulate
或 accelerator.no_sync
*。
减速究竟有多严重,以及您可能会犯的简单错误
为了建立一个现实的例子,请考虑以下设置。
- 两个单 GPU T4 节点和一个具有两个 GPU 的节点。
- 每个 GPU 都是一个 T4,并且托管在 GCP 上。
- 使用的脚本是 NLP 示例 脚本的修改版本。
- 每个 GPU 的批次大小为 16,并且每 4 步累积一次梯度。
所有脚本都可以在 此存储库 中找到。
如果您不注意梯度同步和 GPU 通信,那么在这些 GPU 在不必要的时间段内相互通信时,可能会浪费大量的时间。
减速程度如何呢?
参考
- 基线:不使用这里讨论的任何同步实践。
no_sync
不正确:仅在backward
调用周围使用no_sync
,而不是在forward
周围使用。no_sync
:正确使用no_sync
模式。accumulate
:正确使用 accumulate()。
以下是针对每个设置在单个节点和双节点设置上对 29 批数据进行迭代的每批次平均秒数。
基线 | no_sync 不正确 | no_sync | accumulate | |
---|---|---|---|---|
多节点 | 2±0.01 秒 | 2.13±0.08 秒 | 0.91±0.11 秒 | 0.91±0.11 秒 |
单个节点 | 0.50±0.01 秒 | 0.50±0.01 秒 | 0.41±0.015 秒 | 0.41±0.015 秒 |
如您所见,如果您不注意梯度同步的设置方式,那么在训练期间可能会出现超过 2 倍的减速!
如果您担心确保所有操作都正确执行,我们强烈建议您使用 accumulate() 函数,并将 gradient_accumulation_steps
或 gradient_accumulation_plugin
传递给 Accelerator 对象,以便 Accelerate 可以为您处理此操作。
使用 FSDP 时,no_sync 需要额外的 GPU 内存
请注意,在执行 FSDP 训练时,不同步梯度可能会产生负面影响。正如 torch
中的警告,FSDP 的 no_sync
上下文管理器 将需要额外的内存。
因此,在使用 FSDP 时,在内存密集型情况下,我们建议在 GradientAccumulationPlugin 中将 sync_each_batch
设置为 True
,以禁用 no_sync
。
请参阅以下示例,我们使用 8 个 A100-80GB GPU 对 Mixtral(470 亿个参数)进行微调。我们发现,即使对于适度的 gradient_accumulation_steps=2
,如果启用 no_sync
,我们也会很快出现内存不足 (OOM) 错误。同样,这是由于 FSDP 的 no_sync
导致了额外的内存开销。但是,如果通过 sync_each_batch=True
禁用 no_sync
,那么 gradient_accumulation_steps=16
的内存消耗将恢复到 gradient_accumulation_steps=1
的水平。
模型 | no_sync (accum=1) | no_sync (accum=2) | no_sync 禁用 (accum=16) |
---|---|---|---|
mixtral 8x7B | 69G | OOM | 69G |
禁用 no_sync
意味着由于额外的 数据同步,会出现减速,如本指南前面部分所述。